CUTLASS 1.2
This commit is contained in:
12
CHANGELOG.md
12
CHANGELOG.md
@ -1,9 +1,13 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
|
||||
* Parallelized reductions across threadblocks ("Split-K")
|
||||
* Improved IGEMM performance
|
||||
* Batched strided WMMA GEMMs
|
||||
|
||||
## 1.1.0 (2018-09-19)
|
||||
## [1.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.1.0) (2018-09-19)
|
||||
* Turing Features
|
||||
* WMMA GEMM targeting TensorCores - INT8, INT4, 1-bit
|
||||
* WMMA GEMM targeting TensorCores - INT8, INT4, INT1
|
||||
* Batched Strided GEMM
|
||||
* Threadblock rasterization strategies
|
||||
* Improved performance for adverse problem sizes and data layouts
|
||||
@ -16,13 +20,13 @@
|
||||
* Examples
|
||||
* Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
|
||||
|
||||
## 1.0.1 (2018-06-11)
|
||||
## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)
|
||||
|
||||
* Intra-threadblock reduction added for small threadblock tile sizes
|
||||
* sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
|
||||
* igemm_32x32x128
|
||||
* GEMM _K_ residue handled during prologue prior to mainloop
|
||||
* Replaced Google Test copy with submodule. Use `git submodule init`
|
||||
* Replaced Google Test copy with submodule. Use `git submodule init --recursive --update`
|
||||
|
||||
## [1.0.0](https://github.com/NVIDIA/cutlass/commit/2028ebe120aab22bfd0b2baf8902d4c9627eb33f) (2018-05-16)
|
||||
|
||||
|
||||
@ -141,6 +141,10 @@ else()
|
||||
string(APPEND NVCC_FLAGS " -lineinfo")
|
||||
endif()
|
||||
|
||||
if (UNIX)
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
|
||||
endif()
|
||||
|
||||
string(APPEND NVCC_FLAGS_DEBUG " -g")
|
||||
string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3")
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -O3")
|
||||
@ -169,6 +173,8 @@ file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
|
||||
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
|
||||
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
|
||||
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
|
||||
file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h )
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Define build targets
|
||||
@ -178,6 +184,7 @@ file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
|
||||
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
|
||||
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
|
||||
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
|
||||
source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION})
|
||||
source_group("cutlass" FILES ${CUTLASS_CORE})
|
||||
|
||||
add_library(CUTLASS INTERFACE)
|
||||
@ -187,6 +194,7 @@ target_sources(CUTLASS INTERFACE
|
||||
${CUTLASS_UTIL}
|
||||
${CUTLASS_DEVICE}
|
||||
${CUTLASS_CORE}
|
||||
${CUTLASS_REDUCTION}
|
||||
)
|
||||
|
||||
target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
@ -197,6 +205,7 @@ add_custom_target(cutlass_ide SOURCES
|
||||
${CUTLASS_UTIL}
|
||||
${CUTLASS_DEVICE}
|
||||
${CUTLASS_CORE}
|
||||
${CUTLASS_REDUCTION}
|
||||
)
|
||||
# Doxygen is available. Generate documentation
|
||||
if (DOXYGEN_FOUND)
|
||||
|
||||
64
CUTLASS.md
64
CUTLASS.md
@ -9,6 +9,7 @@ CUTLASS core components, and to identify their role in implementing GEMM computa
|
||||
2. [General Matrix Multiply](#S-general-matrix-multiply)
|
||||
3. [Core Components](#S-core-components)
|
||||
4. [Utilities](#S-utilities)
|
||||
5. [Optimization Strategies](#S-optimization-strategies)
|
||||
|
||||
# <a name="S-design-patterns"></a> 1. Design Patterns
|
||||
|
||||
@ -26,7 +27,7 @@ objectives. This section is intended to provide more detail.
|
||||
|
||||
## <a name="S-patterns-sequencing-nesting"></a> Sequencing and Nesting of Collective Primitives
|
||||
|
||||
CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
|
||||
CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
|
||||
|
||||
## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators
|
||||
|
||||
@ -48,7 +49,7 @@ CUTLASS can take advantage of this CUDA grid-invariant property by constructing
|
||||
|
||||
The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument.
|
||||
|
||||
For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
|
||||
For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
|
||||
|
||||
|
||||
## <a name="S-patterns-composable-shared-memory"></a> Composable shared memory allocation
|
||||
@ -94,7 +95,7 @@ multiply operation performed by each iteration of the mainloop is referred to as
|
||||
|
||||
The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative
|
||||
access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular
|
||||
buffer in shared memory is performed by a _GlobalLoadIterator_.
|
||||
buffer in shared memory is performed by a _GlobalLoadIterator_.
|
||||
|
||||
**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded which may of use to perform type conversion or, in the case of int8 GEMM, transpose 4x4 tiles held in registers.
|
||||
|
||||
@ -109,24 +110,24 @@ The Global Load Stream template contains members defined by the following templa
|
||||
The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product.
|
||||
Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.
|
||||
|
||||
[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
|
||||
[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
|
||||
|
||||
* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h)
|
||||
|
||||
**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
|
||||
**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
|
||||
|
||||
* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h)
|
||||
|
||||
## Thread-level GEMM
|
||||
|
||||
SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply
|
||||
procedures.
|
||||
procedures.
|
||||
|
||||
* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h)
|
||||
* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h)
|
||||
* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h)
|
||||
|
||||
## Epilogue
|
||||
## Epilogue
|
||||
|
||||
The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them by different threads such that a threadblock-scoped tile store operation will make contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:
|
||||
|
||||
@ -227,7 +228,7 @@ must specify compile-time constant tile sizes.
|
||||
## <a name="S-core-tile-structure"></a> Tile Structure
|
||||
|
||||
Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA
|
||||
threads to the problem space. For example, the CUTLASS GEMM
|
||||
threads to the problem space. For example, the CUTLASS GEMM
|
||||
|
||||
Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following
|
||||
members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer
|
||||
@ -286,7 +287,7 @@ the next item in sequence.
|
||||
<img src="/media/images/cutlass-tile-iteration.png" alt="CUTLASS tile access and traversal" width="50%" />
|
||||
|
||||
To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept.
|
||||
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
|
||||
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
|
||||
|
||||
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
|
||||
|
||||
@ -296,9 +297,9 @@ A fragment is analogous to `std::array<>` in that it is a constant-sized array o
|
||||
|
||||
## <a name="S-core-predicate-vector"></a> Predicate Vector
|
||||
|
||||
SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
|
||||
SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
|
||||
|
||||
CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
|
||||
CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
|
||||
|
||||
|
||||
# <a name="S-utilities"></a> 4. Utilities
|
||||
@ -310,6 +311,46 @@ framework offering features such as:
|
||||
* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
|
||||
* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)
|
||||
|
||||
|
||||
# <a name="S-optimization-strategies"></a>5. Optimization Strategies
|
||||
|
||||
This section describes several strategies taken to increase performance beyond what is achievable with
|
||||
a basic implementation of the hierarchical GEMM structure.
|
||||
|
||||
|
||||
## Threadblock Rasterization
|
||||
|
||||
To maximize reuse of data held in the last level cache, CUTLASS defines several functions to
|
||||
affect the mapping of threadblocks to logical partitions of the GEMM problem. These map
|
||||
consecutively launched threadblocks to packed two-dimensional regions of the partitioned GEMM
|
||||
problem to increase the probability that these will access the same tiles of global memory at
|
||||
approximately the same time.
|
||||
|
||||
Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](cutlass/gemm/threadblock_swizzle.h).
|
||||
|
||||
|
||||
## Parallel Reductions across GEMM _K_
|
||||
|
||||
Matrix product computations expose parallelism among _O(MN)_ independent inner product
|
||||
computations. For sufficiently large problem sizes, a GEMM kernel in CUTLASS may approach
|
||||
the theoretical maximum computational throughput. For small problems, however, there are
|
||||
too few threadblocks to efficiently occupy the entire GPU.
|
||||
|
||||
As a recourse, parallelizing the reduction performed during the inner product computation
|
||||
enables more threadblocks to execute concurrently while still taking advantage of the throughput
|
||||
benefits of large threadblock-level GEMM tiles.
|
||||
|
||||
CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension
|
||||
and launching an additional set of threadblocks for each partition. Consequently, we refer to
|
||||
this strategy within CUTLASS as "parallel reduction splitK." The "parallel reduction splitK" in cutlass requires the execution of 2 kernels. The first one is called partitionedK GEMM. The second one is called batched reduction.
|
||||
|
||||
The partitionedK GEMM is very similar to one flavor of batched strided GEMM. Instead of requiring users to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the number of partition that will be applied along K dimension for operand A and B. For example, parameters of m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs with each batch of m=128, n=128, k=256. PartitionedK also allows scenario where k is not divisible by partition count. For example, parameters of m=128, n=128, k=4096 and partition=20 will result in 20 batched strided GEMMs with the first 19 batches of m=128, n=128, k=4096/20=204 and the last batch of m=128, n=128, k=220.
|
||||
|
||||
The batched reduction kernel will further perform reduction along the K-dimension. Thus, the input of the batched reduction kernel is the output (C) of partitionedK GEMM. An workspace memory is managed by the users to store this intermediate results.
|
||||
|
||||
An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_gemm.cu).
|
||||
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
@ -335,4 +376,3 @@ Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
15
README.md
15
README.md
@ -1,10 +1,10 @@
|
||||

|
||||
|
||||
# CUTLASS 1.1
|
||||
# CUTLASS 1.2
|
||||
|
||||
_CUTLASS 1.1.0 - September 2018_
|
||||
_CUTLASS 1.2.0 - October 2018_
|
||||
|
||||
CUTLASS 1.1 is a collection of CUDA C++ template abstractions for implementing
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
@ -22,12 +22,19 @@ point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targe
|
||||
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
|
||||
and beyond.
|
||||
|
||||
CUTLASS 1.1 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
|
||||
CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
We describe the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
# What's New in CUTLASS 1.2
|
||||
_October 2018_
|
||||
* [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
|
||||
* Batched strided WMMA GEMM
|
||||
|
||||
|
||||
# What's New in CUTLASS 1.1
|
||||
_September 2018_
|
||||
|
||||
* [CUTLASS Documentation](CUTLASS.md)
|
||||
* [Examples](examples/)
|
||||
|
||||
@ -313,6 +313,56 @@ struct Coord {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator*(T s, Coord<Rank, Index> coord) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] *= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator*(Coord<Rank, Index> coord, T s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] *= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar division
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator/(T s, Coord<Rank, Index> coord) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] = s / coord[i];
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar division
|
||||
template <typename T, int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator/(Coord<Rank, Index> coord, T s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] /= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Integer-valued make_Coord
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to make a 2-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<1> make_Coord(int _0) {
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUTLASS_MAJOR 1
|
||||
#define CUTLASS_MINOR 1
|
||||
#define CUTLASS_MINOR 2
|
||||
#define CUTLASS_PATCH 0
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
@ -49,21 +49,7 @@
|
||||
|
||||
#define CUTLASS_ASSERT(x) assert(x)
|
||||
|
||||
// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if defined(_MSC_VER)
|
||||
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
||||
#endif
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL
|
||||
#endif
|
||||
|
||||
#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
|
||||
#include "cutlass/util/performance_tuning.h"
|
||||
|
||||
// A small helper class to dump a type at compile time
|
||||
// Usage:: DumpType<Class>::Class
|
||||
|
||||
67
cutlass/gemm/device_gemm.h
Normal file
67
cutlass/gemm/device_gemm.h
Normal file
@ -0,0 +1,67 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief device level GEMM implemented by more than one kernels.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template<typename DeviceGemmTraits_ >
|
||||
struct DeviceGemm {
|
||||
/// The Traits
|
||||
typedef DeviceGemmTraits_ Traits;
|
||||
/// Use the params object defined in traits
|
||||
typedef typename Traits::Params Params;
|
||||
|
||||
/// Support for NVRTC
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Launch the kernels in order
|
||||
static __host__ cudaError_t launch(Params const& params) {
|
||||
Traits::GemmTraits::KernelClass::launch(params.GemmParams);
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess)
|
||||
return err;
|
||||
Traits::ReductionTraits::KernelClass::launch(params.ReductionParams);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
#endif
|
||||
|
||||
///
|
||||
/// Methods
|
||||
///
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE DeviceGemm() {}
|
||||
};
|
||||
} // namespace device_gemm
|
||||
} // namespace cutalss
|
||||
170
cutlass/gemm/device_gemm_traits.h
Normal file
170
cutlass/gemm/device_gemm_traits.h
Normal file
@ -0,0 +1,170 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
#include <assert.h>
|
||||
#include "cutlass/gemm/device_gemm.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/gemm/gemm_desc.h"
|
||||
#include "tools/util/type_traits.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template <
|
||||
/// The Tratis for the first kernel
|
||||
typename GemmTraits_,
|
||||
/// The Traits for the second kernel
|
||||
typename ReductionTraits_
|
||||
>
|
||||
struct SplitkPIGemmTraits {
|
||||
typedef GemmTraits_ GemmTraits;
|
||||
typedef ReductionTraits_ ReductionTraits;
|
||||
typedef SplitkPIGemmTraits<GemmTraits_, ReductionTraits_> This_;
|
||||
typedef typename cutlass::gemm::DeviceGemm<This_> KernelClass;
|
||||
|
||||
///
|
||||
typedef typename GemmTraits::Index Index;
|
||||
///
|
||||
typedef typename ReductionTraits::ScalarAlphaBeta Scalar;
|
||||
///
|
||||
typedef typename GemmTraits::ScalarA ScalarA;
|
||||
///
|
||||
typedef typename GemmTraits::ScalarB ScalarB;
|
||||
///
|
||||
typedef typename GemmTraits::ScalarD ScalarAccum;
|
||||
///
|
||||
typedef typename ReductionTraits::ScalarC ScalarC;
|
||||
///
|
||||
typedef typename ReductionTraits::ScalarD ScalarD;
|
||||
/// The layout of A. can be deduced from the layout set in batched gemm
|
||||
static MatrixLayout::Kind const kLayoutA = GemmTraits::kLayoutA;
|
||||
/// The layout of B. can be deduced from the layout set in batched gemm
|
||||
static MatrixLayout::Kind const kLayoutB = GemmTraits::kLayoutB;
|
||||
|
||||
struct Params {
|
||||
/// The dimensions of the GEMM in K, N, M order
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// Check if params are init
|
||||
bool problem_size_initialized;
|
||||
/// The pointer to workspace memory
|
||||
ScalarAccum *workspace_ptr;
|
||||
///
|
||||
int workspace_size;
|
||||
/// The Params for the first kernel
|
||||
typename GemmTraits::Params GemmParams;
|
||||
/// The Params for the second kernel
|
||||
typename ReductionTraits::Params ReductionParams;
|
||||
|
||||
/// ctor
|
||||
Params() :
|
||||
workspace_size(0),
|
||||
problem_size_initialized(false) {}
|
||||
/// ctor
|
||||
Params(Index m_,
|
||||
Index n_,
|
||||
Index k_
|
||||
):
|
||||
problem_size(k_, n_, m_, 1),
|
||||
workspace_size(0),
|
||||
problem_size_initialized(true) {
|
||||
|
||||
}
|
||||
|
||||
/// init problem is needed if using default ctor
|
||||
void init_problem(Index m_,
|
||||
Index n_,
|
||||
Index k_){
|
||||
problem_size = GemmCoord(k_, n_, m_, 1);
|
||||
problem_size_initialized = true;
|
||||
}
|
||||
|
||||
int initialize(Scalar alpha_,
|
||||
ScalarA const* d_a_,
|
||||
Index lda_,
|
||||
ScalarB const* d_b_,
|
||||
Index ldb_,
|
||||
Scalar beta_,
|
||||
ScalarC const* d_c_,
|
||||
Index ldc_,
|
||||
ScalarD* d_d_,
|
||||
Index ldd_,
|
||||
ScalarAccum *workspace_ptr_) {
|
||||
|
||||
workspace_ptr = workspace_ptr_;
|
||||
|
||||
//call GemmTraits (first kernel) param
|
||||
//for the first kernel A is A, B is B, C and D are workspace
|
||||
//alpha is one, beta is zero, partitionK_count is reductionTraits::reductionSize
|
||||
typename cutlass::gemm::GemmDesc<typename GemmTraits::ScalarA,
|
||||
typename GemmTraits::ScalarB,
|
||||
typename GemmTraits::ScalarC,
|
||||
typename GemmTraits::ScalarD,
|
||||
typename GemmTraits::Epilogue::Scalar>
|
||||
desc(
|
||||
problem_size,
|
||||
typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(1.0f), /*alpha*/
|
||||
TensorRef<typename GemmTraits::ScalarA const, 2>(d_a_, lda_),
|
||||
TensorRef<typename GemmTraits::ScalarB const, 2>(d_b_, ldb_),
|
||||
typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(0.0f), /*beta*/
|
||||
TensorRef<typename GemmTraits::ScalarC const, 2>(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
|
||||
TensorRef<typename GemmTraits::ScalarD, 2>(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
|
||||
);
|
||||
GemmParams.initialize(desc, ReductionTraits::ReductionSize);
|
||||
|
||||
|
||||
//call batched reduction (second kernel) param
|
||||
ReductionParams.initialize(problem_size.m(), /*m*/
|
||||
problem_size.n(), /*n*/
|
||||
alpha_, /*alpha*/
|
||||
beta_, /*beta*/
|
||||
problem_size.n() * problem_size.m() /*reduction_stride*/,
|
||||
workspace_ptr,
|
||||
problem_size.m(),
|
||||
d_c_,
|
||||
ldc_,
|
||||
d_d_,
|
||||
ldd_);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm)
|
||||
// note typedef typename GemmTraits::ScalarD ScalarAccum;
|
||||
// workspace of size of M * N * Reduction
|
||||
int required_workspace_memory_in_byte(){
|
||||
assert(problem_size_initialized == true);
|
||||
workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast<int>(sizeof(ScalarAccum));
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
} // namespace device_gemm
|
||||
} // namespace cutalss
|
||||
@ -243,23 +243,27 @@ struct Gemm {
|
||||
// We may want to use shared memory to clear the registers.
|
||||
typedef typename Traits::ClearAccumulators ClearAccumulators;
|
||||
|
||||
// Get the bounds for each thread, it maybe different than problem_size
|
||||
Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
|
||||
params.partitionK_range);
|
||||
|
||||
// The streams to read A/B from global memory to shared memory.
|
||||
typename Traits::GlobalLoadStream global_to_shared_stream(
|
||||
params.global_to_shared_stream,
|
||||
shared_storage.main_loop.global_to_shared_stream,
|
||||
shared_storage.main_loop.threadblock_tile.reference(),
|
||||
params.problem_size.knm(),
|
||||
bounds,
|
||||
threadblock_offset);
|
||||
|
||||
// update A and B pointer offset based on batch_id and batch_stride_offset
|
||||
//global_to_shared_stream.add_pointer_offset(block_swizzle.get_batch_id(), params.batch_stride_A, params.batch_stride_B);
|
||||
global_to_shared_stream += make_Coord(block_swizzle.get_batch_id(), 0, 0);
|
||||
global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
|
||||
|
||||
// Create the accumulator clear.
|
||||
ClearAccumulators clear;
|
||||
|
||||
// Deal with residue in prolog.
|
||||
global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
|
||||
// global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
|
||||
global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
|
||||
|
||||
// Fetch the fragments for A and B from global memory.
|
||||
global_to_shared_stream.copy();
|
||||
@ -271,7 +275,8 @@ struct Gemm {
|
||||
Traits::shared_store_fence(false);
|
||||
|
||||
// Rollback to the beginning of the first tile (if residue exists).
|
||||
global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
|
||||
// global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
|
||||
global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
|
||||
|
||||
// The stream of data from shared memory to fragments.
|
||||
typename Traits::SharedStream shared_load_stream(
|
||||
@ -288,18 +293,17 @@ struct Gemm {
|
||||
clear.clear(accumulators);
|
||||
|
||||
// Initial index
|
||||
Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
|
||||
|
||||
// Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
|
||||
// problem_size[0] might be bigger than bounds[0]
|
||||
Index outer_k = bounds[0] - Traits::OutputTile::kD;
|
||||
// Check if we are computing residue in prolog or not.
|
||||
if (Traits::GemmConfig::kResidueInProlog) {
|
||||
|
||||
// Execute all mainloop iterations but the last one.
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<false, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
|
||||
// Don't load data for the last "residue" portion since we've already computed the residue.
|
||||
@ -307,7 +311,6 @@ struct Gemm {
|
||||
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<false, true>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
} else {
|
||||
// When kResidueSeparate = true, execute all mainloop iterations but the last two without any
|
||||
@ -319,17 +322,14 @@ struct Gemm {
|
||||
for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<false, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Execute remaining tiles with K-residue predicate updates enabled.
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<true, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -127,6 +127,12 @@ struct GemmCoord : public Coord<4, int> {
|
||||
Coord<2> nm() const {
|
||||
return make_Coord(n(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> mn() const {
|
||||
return make_Coord(m(), n());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -131,20 +131,19 @@ struct GemmEpilogue {
|
||||
params.iterator_c, problem_size, block, pointer_offset, predicate_offset);
|
||||
|
||||
// update C pointer offset based on batch_id and batch_stride_offset
|
||||
//global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_c);
|
||||
global_load_iterator += make_Coord(batch_id, 0, 0);
|
||||
global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_C);
|
||||
|
||||
// The transformer for C.
|
||||
GlobalTransformerC transformer_c;
|
||||
// The transformer for D.
|
||||
GlobalTransformerD transformer_d;
|
||||
|
||||
// The iterator to store into the D matrix.
|
||||
GlobalStoreIteratorD global_store_iterator(
|
||||
params.iterator_d, problem_size, block, pointer_offset, predicate_offset);
|
||||
|
||||
// update D pointer offset based on batch_id and batch_stride_offset
|
||||
//global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_d);
|
||||
global_store_iterator += make_Coord(batch_id, 0, 0);
|
||||
global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_D);
|
||||
|
||||
SharedStoreTransformerD shared_store_transformer;
|
||||
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
|
||||
@ -171,6 +170,7 @@ struct GemmEpilogue {
|
||||
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
|
||||
|
||||
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
|
||||
|
||||
shared_store_iterator.store_post_increment(shared_store_transformed_d);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
@ -182,7 +182,6 @@ struct GemmEpilogue {
|
||||
|
||||
// Do the math.
|
||||
typename GlobalTransformerD::InputFragment fragment_d;
|
||||
|
||||
if (kSourceRequired) {
|
||||
// Transform C fragment.
|
||||
transformer_c.transform(fragment_c, transformed_c);
|
||||
|
||||
@ -97,6 +97,8 @@ struct GemmEpilogueTraits {
|
||||
typedef Functor_ Functor;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The long index
|
||||
typedef long long LongIndex;
|
||||
|
||||
/// We do not support 3D or 4D shapes.
|
||||
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
|
||||
@ -114,8 +116,16 @@ struct GemmEpilogueTraits {
|
||||
Index stride_h, stride_w;
|
||||
/// The params for the C iterator.
|
||||
typename GlobalLoadIteratorC::Params iterator_c;
|
||||
|
||||
/// Batch stride for C matrix
|
||||
LongIndex batch_stride_C;
|
||||
|
||||
/// The params for the D global iterator.
|
||||
typename GlobalStoreIteratorD::Params iterator_d;
|
||||
|
||||
/// Batch stride for C matrix
|
||||
LongIndex batch_stride_D;
|
||||
|
||||
/// The params for the D shared store iterator.
|
||||
typename SharedStoreIteratorD::Params shared_store_iterator_d;
|
||||
/// The params for the D shared load stream.
|
||||
@ -139,22 +149,29 @@ struct GemmEpilogueTraits {
|
||||
this->stride_w = 0;
|
||||
// Setup the params for the global memory iterator for C.
|
||||
error_code = iterator_c.initialize(desc.C.data(),
|
||||
desc.batch_stride_C,
|
||||
desc.C.leading_dim(),
|
||||
desc.C.leading_dim(),
|
||||
desc.problem_size[1],
|
||||
stride_w,
|
||||
Delta::kW);
|
||||
|
||||
batch_stride_C = desc.batch_stride_C;
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for D.
|
||||
return iterator_d.initialize(desc.D.data(),
|
||||
desc.batch_stride_D,
|
||||
error_code = iterator_d.initialize(desc.D.data(),
|
||||
desc.D.leading_dim(),
|
||||
desc.D.leading_dim(),
|
||||
desc.problem_size[1],
|
||||
stride_w,
|
||||
Delta::kW);
|
||||
|
||||
batch_stride_D = desc.batch_stride_D;
|
||||
|
||||
return error_code;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -80,6 +80,8 @@ struct GlobalLoadStream {
|
||||
typedef typename LoadIterator::Pointer Pointer;
|
||||
/// The index.
|
||||
typedef typename LoadIterator::Index Index;
|
||||
/// The index.
|
||||
typedef typename LoadIterator::LongIndex LongIndex;
|
||||
/// The tile
|
||||
typedef typename LoadIterator::Tile Tile;
|
||||
|
||||
@ -94,24 +96,46 @@ struct GlobalLoadStream {
|
||||
struct Params {
|
||||
// The load iterator.
|
||||
typename LoadIterator::Params load_iterator;
|
||||
|
||||
/// Batch stride in global memory
|
||||
LongIndex batch_stride;
|
||||
|
||||
// The store iterator.
|
||||
typename StoreIterator::Params store_iterator;
|
||||
|
||||
// Offset to residue.
|
||||
Index offset_to_residue;
|
||||
|
||||
// Offset to residue for the last partition
|
||||
Index offset_to_residue_last_partition;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
long long batch_stride,
|
||||
LongIndex batch_stride_,
|
||||
Index ldm,
|
||||
Index _offset_to_residue) {
|
||||
Index offset_to_residue_,
|
||||
Index offset_to_residue_last_partition_) {
|
||||
|
||||
offset_to_residue = _offset_to_residue;
|
||||
int error_code = load_iterator.initialize(pointer, batch_stride, ldm);
|
||||
int error_code = load_iterator.initialize(pointer, ldm, ldm);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
batch_stride = batch_stride_;
|
||||
offset_to_residue = offset_to_residue_;
|
||||
offset_to_residue_last_partition = offset_to_residue_last_partition_;
|
||||
|
||||
return store_iterator.initialize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE Index get_offset_to_residue() {
|
||||
if (blockIdx.z == gridDim.z - 1) { //last partition
|
||||
return offset_to_residue_last_partition;
|
||||
}
|
||||
else {
|
||||
return offset_to_residue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Contains private storage in shared memory needed by the objects within this class. Note,
|
||||
@ -124,7 +148,7 @@ struct GlobalLoadStream {
|
||||
//
|
||||
|
||||
/// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
|
||||
CUTLASS_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
|
||||
CUTLASS_HOST_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
|
||||
bool const kKstrided =
|
||||
GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;
|
||||
Coord<3> tile_coord = ProjectOperand<kOperand, kKstrided>::project(coord);
|
||||
@ -140,21 +164,20 @@ struct GlobalLoadStream {
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const& _threadblock_offset)
|
||||
: params(_params),
|
||||
multiplicand_bounds(project_coordinate(bounds, 1)),
|
||||
threadblock_offset(project_coordinate(_threadblock_offset)),
|
||||
load_iterator(params.load_iterator,
|
||||
project_coordinate(bounds, 1), /*multiplicant_bounds*/
|
||||
project_coordinate(_threadblock_offset) /*threablock_offset*/),
|
||||
multiplicand_bounds(project_coordinate(bounds, 1)),
|
||||
load_iterator(params.load_iterator, threadblock_offset),
|
||||
transformer(),
|
||||
store_iterator(params.store_iterator, threadblock_tile_ref.data())
|
||||
{
|
||||
store_iterator(params.store_iterator, threadblock_tile_ref.data()) {
|
||||
load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
|
||||
fetched_fragment.clear();
|
||||
}
|
||||
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy() { load_iterator.load_post_increment(fetched_fragment); }
|
||||
CUTLASS_DEVICE void copy() {
|
||||
load_iterator.load_post_increment(fetched_fragment);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
@ -176,8 +199,9 @@ struct GlobalLoadStream {
|
||||
Index kResidue = k % kTileK;
|
||||
if (kResidue) {
|
||||
residue(kResidue);
|
||||
Index this_offset_residue = params.get_offset_to_residue();
|
||||
load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
|
||||
}
|
||||
load_iterator.add_pointer_offset(params.offset_to_residue * load_iterator.stride_advance());
|
||||
}
|
||||
|
||||
/// Rollback to the beginning of the first tile
|
||||
@ -187,9 +211,9 @@ struct GlobalLoadStream {
|
||||
int const kBlock = kOperand == GemmOperand::kA
|
||||
? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
|
||||
: (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
|
||||
|
||||
load_iterator.add_pointer_offset(-(params.offset_to_residue + kBlock) *
|
||||
load_iterator.stride_advance());
|
||||
Index this_offset_residue = params.get_offset_to_residue();
|
||||
load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
|
||||
load_iterator.stride_advance());
|
||||
}
|
||||
|
||||
/// Adds a Coord<3> to the underlying global load iterator
|
||||
@ -198,16 +222,22 @@ struct GlobalLoadStream {
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Adds an offset based on batch stride
|
||||
CUTLASS_DEVICE GlobalLoadStream &add_batch_offset(int batch_id) {
|
||||
load_iterator.add_pointer_offset(batch_id * params.batch_stride);
|
||||
return *this;
|
||||
}
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters
|
||||
Params params;
|
||||
/// Multiplicand bounds
|
||||
Coord<3> multiplicand_bounds;
|
||||
/// Threadblock offset
|
||||
Coord<3> threadblock_offset;
|
||||
/// Multiplicand bounds
|
||||
Coord<3> multiplicand_bounds;
|
||||
/// The iterator.
|
||||
LoadIterator load_iterator;
|
||||
/// The fragment to fetch from shared memory.
|
||||
|
||||
@ -188,6 +188,8 @@ struct GemmGlobalIteratorAb
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// Long index
|
||||
typedef long long LongIndex;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
@ -201,35 +203,9 @@ struct GemmGlobalIteratorAb
|
||||
struct Params : public BaseParams {
|
||||
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
|
||||
CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr,
|
||||
long long stride_d,
|
||||
Index stride_d,
|
||||
Index stride_h) {
|
||||
Index inc_d = 0;
|
||||
Index inc_advance = 0;
|
||||
// Move by some columns for each iteration in the H dimension.
|
||||
Index inc_h = Base::Delta::kH * stride_h;
|
||||
|
||||
// Move by some more columns in the number of iterations if the D dimension is > 1.
|
||||
if (Base::Delta::kD > 0) {
|
||||
inc_d = Base::Delta::kD * stride_h - (Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
// Move to the beginning of the next iteration.
|
||||
if (kAdvance == IteratorAdvance::kH && Base::Delta::kD > 0) {
|
||||
inc_advance = inc_d;
|
||||
} else if (kAdvance == IteratorAdvance::kH) {
|
||||
inc_advance = inc_h;
|
||||
} else if (Base::Delta::kD > 0) {
|
||||
inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
|
||||
(Base::Iterations::kH - 1) * inc_h -
|
||||
(Base::Iterations::kD - 1) * Base::Delta::kD * stride_h;
|
||||
} else {
|
||||
inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
|
||||
(Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
Base::Params::initialize(
|
||||
ptr, stride_d, stride_h, 1, inc_d, inc_h, 0, inc_advance);
|
||||
return 0;
|
||||
return BaseParams::initialize(ptr, stride_d, stride_h, kAdvance == IteratorAdvance::kH ? 0 : 1);
|
||||
}
|
||||
};
|
||||
|
||||
@ -268,7 +244,6 @@ struct GemmGlobalIteratorAb
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& threadblock_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
@ -304,11 +279,6 @@ struct GemmGlobalIteratorAb
|
||||
|
||||
/// That's the residue! Update the predicates.
|
||||
CUTLASS_HOST_DEVICE void residue(Index k) {
|
||||
// The coordinates of the thread.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
|
||||
// Update the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
@ -316,9 +286,9 @@ struct GemmGlobalIteratorAb
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
Index offset = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
|
||||
offset += thread_offset[1] + h * Base::Delta::kH + d * Base::Delta::kD;
|
||||
} else {
|
||||
offset += block_w + w * Base::Delta::kW;
|
||||
offset += thread_offset[2] + w * Base::Delta::kW;
|
||||
}
|
||||
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
@ -340,7 +310,7 @@ struct GemmGlobalIteratorAb
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord<3> const &offset) {
|
||||
|
||||
long long _offset = offset.template dot<long long>(
|
||||
LongIndex _offset = offset.template dot<LongIndex>(
|
||||
make_Coord(params.stride_d, params.stride_h, params.stride_w)
|
||||
);
|
||||
|
||||
@ -419,6 +389,8 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The index.
|
||||
typedef long long LongIndex;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
@ -439,7 +411,7 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
long long batch_stride,
|
||||
int stride_d_,
|
||||
Index ldm,
|
||||
Index bound,
|
||||
Index epilogue_stride_w,
|
||||
@ -447,7 +419,7 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Stride per batch
|
||||
stride_d = batch_stride;
|
||||
stride_d = stride_d_;
|
||||
// Each column of the matrix.
|
||||
stride_h = TileTraits_::ThreadsDelta::kH * ldm;
|
||||
// Each thread output 1 column per iteration. The stride between columns is given by the
|
||||
@ -463,6 +435,21 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long _stride_d, Index _stride_h,
|
||||
Index _inc_advance, Index _inc_h, Index _predicate_inc_advance, Index _predicate_inc_h,
|
||||
Index _predicate_offset) {
|
||||
this->pointer = pointer;
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
inc_advance = _inc_advance;
|
||||
inc_h = _inc_h;
|
||||
predicate_inc_advance = _predicate_inc_advance;
|
||||
predicate_inc_h = _predicate_inc_h;
|
||||
predicate_offset = _predicate_offset;
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters.
|
||||
@ -471,20 +458,7 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
Coord<4> thread_offset;
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, thread_offset[2] + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
@ -527,7 +501,7 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord<3> const &offset) {
|
||||
long long _offset = offset.template dot<long long>(
|
||||
LongIndex _offset = offset.template dot<LongIndex>(
|
||||
make_Coord(params.stride_d, params.stride_h, 1)
|
||||
);
|
||||
params.pointer += _offset;
|
||||
@ -568,7 +542,7 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
}
|
||||
|
||||
/// add pointer offset
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }
|
||||
|
||||
/// Loads and increments iterator
|
||||
template <typename Fragment>
|
||||
|
||||
@ -92,7 +92,9 @@ struct SharedLoadStream {
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy() { iterator.load_post_increment(fetched[0]); }
|
||||
CUTLASS_DEVICE void copy() {
|
||||
iterator.load_post_increment(fetched[0]);
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int step) { iterator.load(fetched[step % 2], step); }
|
||||
|
||||
@ -111,7 +111,7 @@ struct GlobalLoadStreamPair {
|
||||
CUTLASS_DEVICE GlobalLoadStreamPair(Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
ThreadblockTileRef const &threadblock_tile_ref,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0))
|
||||
: stream_a(params.stream_a,
|
||||
shared_storage.stream_a,
|
||||
@ -131,6 +131,13 @@ struct GlobalLoadStreamPair {
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GlobalLoadStreamPair & add_batch_offset(int batch_id) {
|
||||
stream_a.add_batch_offset(batch_id);
|
||||
stream_b.add_batch_offset(batch_id);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
stream_a.copy();
|
||||
|
||||
@ -418,6 +418,9 @@ struct GemmTraits {
|
||||
/// GEMM problem size
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// The K range for every partition except the last one
|
||||
int partitionK_range;
|
||||
|
||||
/// Parameters object for the global load stream
|
||||
typename GlobalLoadStream::Params global_to_shared_stream;
|
||||
|
||||
@ -433,6 +436,8 @@ struct GemmTraits {
|
||||
// Set the problem size.
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// there is no partitionK in the default case
|
||||
partitionK_range = problem_size[0];
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
@ -441,15 +446,18 @@ struct GemmTraits {
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue.
|
||||
// partitionK_range <= problem_size[0]
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
|
||||
|
||||
// Initialize parameters objects for
|
||||
int error_code = global_to_shared_stream.stream_a.initialize(
|
||||
desc.A.data(),
|
||||
desc.batch_stride_A,
|
||||
desc.A.leading_dim(),
|
||||
offset_to_residue
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
@ -459,7 +467,8 @@ struct GemmTraits {
|
||||
desc.B.data(),
|
||||
desc.batch_stride_B,
|
||||
desc.B.leading_dim(),
|
||||
offset_to_residue
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
@ -516,7 +525,6 @@ struct GemmTraits {
|
||||
Index ldd,
|
||||
long long int batch_stride_D,
|
||||
Index batch_count) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, batch_count),
|
||||
alpha,
|
||||
@ -533,6 +541,121 @@ struct GemmTraits {
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a partitionedK GEMM params
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) {
|
||||
// partitionK GEMM is a specialized batched stried gemm with different K ranges per batch
|
||||
// the problem_size of each batch is (lastK_size, n, m)
|
||||
// add more comments here
|
||||
// the k range for every batch excpet the last one
|
||||
//assert(partitionK_count_ > 0);
|
||||
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
|
||||
// the k range of the last batch
|
||||
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
|
||||
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
|
||||
int k_size = lastK_range;
|
||||
int lda = partitonK_desc.A.stride(0);
|
||||
int ldb = partitonK_desc.B.stride(0);
|
||||
int ldc = partitonK_desc.C.stride(0);
|
||||
int ldd = partitonK_desc.D.stride(0);
|
||||
int n = partitonK_desc.problem_size.n();
|
||||
|
||||
|
||||
long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
|
||||
long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
|
||||
long long int batch_stride_C = ldc * n;
|
||||
long long int batch_stride_D = ldd * n;
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
//we pass lastK_size as per batch K. there is also a range that will match partitionK_size
|
||||
GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_),
|
||||
partitonK_desc.alpha,
|
||||
partitonK_desc.A,
|
||||
batch_stride_A,
|
||||
partitonK_desc.B,
|
||||
batch_stride_B,
|
||||
partitonK_desc.beta,
|
||||
partitonK_desc.C,
|
||||
batch_stride_C,
|
||||
partitonK_desc.D,
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
// Set the problem size.
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue.
|
||||
// partitionK_range <= problem_size[0]
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
|
||||
|
||||
// Initialize parameters objects for
|
||||
int error_code = global_to_shared_stream.stream_a.initialize(
|
||||
desc.A.data(),
|
||||
desc.batch_stride_A,
|
||||
desc.A.leading_dim(),
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
error_code = global_to_shared_stream.stream_b.initialize(
|
||||
desc.B.data(),
|
||||
desc.batch_stride_B,
|
||||
desc.B.leading_dim(),
|
||||
offset_to_residue,
|
||||
offset_to_residue_last_partition
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
|
||||
|
||||
|
||||
/// Helper to construct a partitionedK GEMM params
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
Index partitionK_count_) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, 1),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
TensorRef<ScalarD, 2>(d_d, ldd)
|
||||
);
|
||||
|
||||
|
||||
return this->initialize(desc, partitionK_count_);
|
||||
}
|
||||
};
|
||||
|
||||
// The storage for the main loop + prologue.
|
||||
|
||||
@ -100,10 +100,13 @@ struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_>
|
||||
|
||||
/// Constructor.
|
||||
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& threadblock_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: Base(_params, bounds, threadblock_offset, thread_offset_func), mask_(0xffffffff) {
|
||||
: Base(_params, threadblock_offset, thread_offset_func), mask_(0xffffffff) { }
|
||||
|
||||
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& threadblock_offset) {
|
||||
|
||||
Base::initialize_predicates(bounds, threadblock_offset);
|
||||
// The number of elements read in a single iteration.
|
||||
int const kBlock = TileTraits_::Tile::kW;
|
||||
// The residue.
|
||||
|
||||
@ -71,6 +71,8 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
|
||||
// The inputs.
|
||||
int const* a_int = reinterpret_cast<int const*>(&a[0]);
|
||||
int const* b_int = reinterpret_cast<int const*>(&b[0]);
|
||||
@ -82,6 +84,7 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
|
||||
: "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -80,7 +80,7 @@ struct IdentityBlockSwizzle {
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
///get threadblock offset, without considering tha batch dim
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
@ -93,6 +93,26 @@ struct IdentityBlockSwizzle {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
|
||||
/// check if at the last partition
|
||||
CUTLASS_DEVICE bool is_last_partition() {
|
||||
if (get_batch_id() == (gridDim.z - 1))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
|
||||
int partitionK_range) {
|
||||
// every partition except the last one has a smaller range
|
||||
// partitionK_range is the bounds for every partition except the last one
|
||||
// the last partition's bounds is the same with problem size
|
||||
if(is_last_partition())
|
||||
return problem_size.knm();
|
||||
else
|
||||
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -226,6 +246,26 @@ struct ColumnMajorBlockSwizzle {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
|
||||
/// check if at the last partition
|
||||
CUTLASS_DEVICE bool is_last_partition() {
|
||||
if (get_batch_id() == (gridDim.z - 1))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
|
||||
int partitionK_range) {
|
||||
// every partition except the last one has a smaller range
|
||||
// partitionK_range is the bounds for every partition except the last one
|
||||
// the last partition's bounds is the same with problem size
|
||||
if (is_last_partition())
|
||||
return problem_size.knm();
|
||||
else
|
||||
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -379,6 +419,26 @@ struct RowMajorBlockSwizzle {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
|
||||
/// check if at the last partition
|
||||
CUTLASS_DEVICE bool is_last_partition() {
|
||||
if (get_batch_id() == (gridDim.z - 1) )
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
|
||||
int partitionK_range) {
|
||||
// every partition except the last one has a smaller range
|
||||
// partitionK_range is the bounds for every partition except the last one
|
||||
// the last partition's bounds is the same with problem size
|
||||
if (is_last_partition())
|
||||
return problem_size.knm();
|
||||
else
|
||||
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -45,7 +45,7 @@ namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
|
||||
template <typename GemmConfig_, typename Accumulator_, typename EpilogueFunctor_, typename Index_ = int>
|
||||
struct WmmaGemmEpilogueTraitsHelper {
|
||||
/// The scalar.
|
||||
typedef typename EpilogueFunctor_::Scalar Scalar;
|
||||
@ -104,7 +104,10 @@ struct WmmaGemmEpilogueTraitsHelper {
|
||||
// The number of threads.
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsD>
|
||||
GemmConfig_::kScalarsPerLdsD,
|
||||
// this parameter helps with swizzling when accum is fp32 and output is fp16
|
||||
sizeof(Accumulator_) / sizeof(typename GemmConfig_::ScalarD)
|
||||
>
|
||||
SharedLoadTileTraits;
|
||||
|
||||
/// The iterator to load D from shared memory.
|
||||
|
||||
@ -103,18 +103,18 @@ struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index
|
||||
Index epilogue_stride_w,
|
||||
Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
BaseParams::pointer = pointer;
|
||||
this->pointer = pointer;
|
||||
// Stride between GEMMs
|
||||
BaseParams::stride_d = batch_stride;
|
||||
this->stride_d = batch_stride;
|
||||
// Setup the base stride. One "group of threads" per column.
|
||||
BaseParams::stride_h = ldm;
|
||||
this->stride_h = ldm;
|
||||
// Each thread output 1 column per iteration. .
|
||||
BaseParams::inc_h = ldm * TileTraits_::Threads::kH;
|
||||
BaseParams::inc_advance = BaseParams::inc_h + epilogue_stride_w;
|
||||
this->inc_h = ldm * TileTraits_::Threads::kH;
|
||||
this->inc_advance = this->inc_h + epilogue_stride_w;
|
||||
|
||||
BaseParams::predicate_offset = n;
|
||||
BaseParams::predicate_inc_h = TileTraits_::Threads::kH;
|
||||
BaseParams::predicate_inc_advance = BaseParams::predicate_inc_h + epilogue_delta_w;
|
||||
this->predicate_offset = n;
|
||||
this->predicate_inc_h = TileTraits_::Threads::kH;
|
||||
this->predicate_inc_advance = this->predicate_inc_h + epilogue_delta_w;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -173,6 +173,7 @@ struct WmmaGemmSharedStoreTileDTraits {
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;
|
||||
|
||||
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -192,7 +193,7 @@ struct WmmaGemmSharedStoreTileDTraits {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_>
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_, int kLdsPerAccess_ = 1>
|
||||
struct WmmaGemmSharedLoadTileDTraits {
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
@ -201,7 +202,7 @@ struct WmmaGemmSharedLoadTileDTraits {
|
||||
/// The access size
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerLds_>::Tile Tile;
|
||||
typedef typename WmmaReshapeTile<Tile_, kScalarsPerLds_, kLdsPerAccess_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
|
||||
/// The threads strides.
|
||||
@ -212,12 +213,13 @@ struct WmmaGemmSharedLoadTileDTraits {
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_>
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_, kScalarsPerLds_>
|
||||
ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
|
||||
Iterations;
|
||||
|
||||
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -46,7 +46,7 @@ namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
@ -68,7 +68,18 @@ template <
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_>
|
||||
int kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for A.
|
||||
int KScalarsPerLdsA_,
|
||||
/// The number of scalars per LDS for B.
|
||||
int KscalarsPerLdsB_,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
int kScalarsPerLdgCAndStgD_,
|
||||
/// The number of scalars per STS for D.
|
||||
int kScalarsPerStsD_,
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_
|
||||
>
|
||||
struct WmmaGemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
ScalarA_,
|
||||
@ -94,19 +105,19 @@ struct WmmaGemmConfig : public GemmConfig<
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
KScalarsPerLdsA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
KscalarsPerLdsB_,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
kScalarsPerLdgCAndStgD_,
|
||||
/// The number of scalars per STS for D.
|
||||
16 / sizeof(Accumulator_),
|
||||
kScalarsPerStsD_,
|
||||
/// The number of scalars per LDS for D.
|
||||
16 / sizeof(Accumulator_),
|
||||
kScalarsPerLdsD_,
|
||||
/// The number of stages in shared memory.
|
||||
1,
|
||||
/// If true, residue is computed in mainloop. If false, separate loops are instantiated.
|
||||
@ -955,6 +966,16 @@ template <
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for A.
|
||||
int KScalarsPerLdsA_,
|
||||
/// The number of scalars per LDS for B.
|
||||
int KscalarsPerLdsB_,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
int kScalarsPerLdgCAndStgD_,
|
||||
/// The number of scalars per STS for D.
|
||||
int kScalarsPerStsD_,
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_,
|
||||
/// The index.
|
||||
typename Index_>
|
||||
struct WmmaGemmTraitsHelper {
|
||||
@ -969,7 +990,13 @@ struct WmmaGemmTraitsHelper {
|
||||
WarpGemmShape_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_>
|
||||
kScalarsPerLdgB_,
|
||||
KScalarsPerLdsA_,
|
||||
KscalarsPerLdsB_,
|
||||
kScalarsPerLdgCAndStgD_,
|
||||
kScalarsPerStsD_,
|
||||
kScalarsPerLdsD_
|
||||
>
|
||||
GemmConfig;
|
||||
|
||||
/// The GEMM config for A.
|
||||
@ -1042,7 +1069,7 @@ struct WmmaGemmTraitsHelper {
|
||||
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
|
||||
|
||||
/// The helper to create the epilogue traits.
|
||||
typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
|
||||
typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, Accumulator_, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
|
||||
/// The traits class for the epilogue.
|
||||
typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
|
||||
GemmEpilogueTraits;
|
||||
@ -1084,6 +1111,16 @@ template <
|
||||
int kScalarsPerLdgA_ = 8,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 8,
|
||||
/// The number of scalars per LDS for A.
|
||||
int KScalarsPerLdsA_ = 8,
|
||||
/// The number of scalars per LDS for B.
|
||||
int KscalarsPerLdsB_ = 8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
int kScalarsPerLdgCAndStgD_ = 16 / sizeof(ScalarC_),
|
||||
/// The number of scalars per STS for D.
|
||||
int kScalarsPerStsD_ = 16 / sizeof(Accumulator_),
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_ = 16 / sizeof(Accumulator_),
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
@ -1099,6 +1136,11 @@ template <
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
KScalarsPerLdsA_,
|
||||
KscalarsPerLdsB_,
|
||||
kScalarsPerLdgCAndStgD_,
|
||||
kScalarsPerStsD_,
|
||||
kScalarsPerLdsD_,
|
||||
Index_> >
|
||||
struct WmmaGemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
|
||||
@ -153,7 +153,7 @@ struct MatrixCoord : public Coord<2, int> {
|
||||
//
|
||||
// Coord<TensorRefMapFunc::kStorageRank> stride = TensorRefMapFunc::stride(leading_dim);
|
||||
//
|
||||
struct MatrixLayout {
|
||||
namespace MatrixLayout {
|
||||
|
||||
/// Enumeration defining fundamental contiguous layouts.
|
||||
enum Kind { kRowMajor, kColumnMajor };
|
||||
|
||||
175
cutlass/reduction/batched_reduction.h
Normal file
175
cutlass/reduction/batched_reduction.h
Normal file
@ -0,0 +1,175 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient batched reduction.
|
||||
D = alpha * Reduction(A) + beta * C
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reduction {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename batched_reduction_>
|
||||
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_::Params params) {
|
||||
// Construct the batched_reduction object
|
||||
batched_reduction_ batched_reduction(params);
|
||||
batched_reduction.run();
|
||||
}
|
||||
|
||||
template <typename BatchedReductionTraits_>
|
||||
struct BatchedReduction {
|
||||
/// This class
|
||||
typedef BatchedReduction<BatchedReductionTraits_> This_;
|
||||
/// The traits
|
||||
typedef BatchedReductionTraits_ Traits;
|
||||
/// Params
|
||||
typedef typename Traits::Params Params;
|
||||
/// functor
|
||||
typedef typename Traits::Functor Functor;
|
||||
|
||||
/// ctor
|
||||
CUTLASS_DEVICE BatchedReduction(Params const ¶ms_)
|
||||
: params(params_), functor(params_.functorParams) {}
|
||||
|
||||
/// main operation method
|
||||
/// D = alpha * Reduction(A) + beta * C
|
||||
CUTLASS_DEVICE void run() {
|
||||
#if (__CUDA_ARCH__ >= 600)
|
||||
// Swizzle the IDs of the block
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
Coord<3> threadblock_offset =
|
||||
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());
|
||||
|
||||
int subTileSize = gridDim.x * Traits::SubTile::kW;
|
||||
int tileSize = params.problem_size[1] * params.problem_size[2];
|
||||
int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW;
|
||||
|
||||
int subTileBase = 0;
|
||||
|
||||
typename Traits::ScalarA inRegs[Traits::maxInReg];
|
||||
typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
|
||||
|
||||
for (int subTile = 0; subTile < tileSize; subTile += subTileSize) {
|
||||
int tileOffset = subTileBase + subTileOffset;
|
||||
// Init AccumRegs
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++)
|
||||
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
|
||||
// Fetch c0
|
||||
typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
|
||||
for (int i = 0; i< Traits::ThreadShape::kW; i++)
|
||||
c0[i] = static_cast<typename Traits::ScalarAccum>(params.d_c[tileOffset + i]);
|
||||
|
||||
// Fetch partial sums from A
|
||||
#pragma unroll
|
||||
for (int s = 0; s < Traits::ReductionSize; s++) {
|
||||
int inRegOffset = s * Traits::ThreadShape::kW;
|
||||
int dOffset = (s * tileSize) + tileOffset;
|
||||
#pragma unroll
|
||||
for (int i = 0; i< Traits::ThreadShape::kW; i++) {
|
||||
inRegs[inRegOffset + i] = params.d_a[dOffset + i];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate
|
||||
#pragma unroll
|
||||
for (int s = 0; s < Traits::ReductionSize; s++) {
|
||||
int inRegOffset = s * Traits::ThreadShape::kW;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
|
||||
//AccumRegs[i] = cuFma(params.alpha, inRegs[inRegOffset + i], AccumRegs[i]);
|
||||
//AccumRegs[i] = params.alpha * inRegs[inRegOffset + i] + AccumRegs[i];
|
||||
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(inRegs[inRegOffset + i]) + AccumRegs[i];
|
||||
}
|
||||
}
|
||||
// calling functor
|
||||
functor_caller<Traits::ThreadShapeMultiple2>(AccumRegs, c0, AccumRegs);
|
||||
|
||||
// Store AccumRegs to D
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
|
||||
params.d_d[tileOffset + i] = static_cast<typename Traits::ScalarD>(AccumRegs[i]);
|
||||
}
|
||||
|
||||
// Advance sub-tile pointer
|
||||
subTileBase += subTileSize;
|
||||
} // end for loop
|
||||
#endif //#if (__CUDA_ARCH__ >= 600)
|
||||
}
|
||||
|
||||
template<bool ThreadShapeMultiple2>
|
||||
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output) {
|
||||
if (ThreadShapeMultiple2 == true) {
|
||||
for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
|
||||
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
|
||||
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Static function members
|
||||
//
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(Params const& params,
|
||||
cudaStream_t stream = cudaStreamDefault) {
|
||||
// Setup the grid.
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
dim3 grid = block_swizzle.get_grid_layout(params.problem_size,
|
||||
make_Coord_from_shape<typename Traits::OutputTile>());
|
||||
|
||||
dim3 block;
|
||||
block.x = Traits::kThreads;
|
||||
batched_reduction_kernel<This_><<<grid, block, 0, stream>>>(params);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
// The functor.
|
||||
Functor functor;
|
||||
};
|
||||
|
||||
} // namespace reduction
|
||||
} // namespace cutlass
|
||||
192
cutlass/reduction/batched_reduction_traits.h
Normal file
192
cutlass/reduction/batched_reduction_traits.h
Normal file
@ -0,0 +1,192 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of complete batched reduction.
|
||||
D = alpha * Reduction(A) + beta * C
|
||||
*/
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/reduction/threadblock_swizzle.h"
|
||||
#include "cutlass/reduction/batched_reduction.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reduction {
|
||||
|
||||
/*
|
||||
OutputTile defines the work load per thread block
|
||||
Subtile defines the work load per thread block per iteration
|
||||
OutputTile / Subtile = number of iterations within a kernel
|
||||
ThreadShape defines the work load per thread
|
||||
Subtile / ThreadShape = number of threads per thread block
|
||||
*/
|
||||
template <
|
||||
/// The scalar type for A
|
||||
typename ScalarA_,
|
||||
/// The scalar type for C
|
||||
typename ScalarC_,
|
||||
/// The scalar type for D
|
||||
typename ScalarD_,
|
||||
/// the scalar type for alpha,
|
||||
typename ScalarAlphaBeta_,
|
||||
/// The scalar type for accumulator
|
||||
typename ScalarAccum_,
|
||||
/// Reduction work load per batch
|
||||
int ReductionSize_ = 1,
|
||||
/// The output tile, work load per thread block,
|
||||
typename OutputTile_ = Shape<1, 1, 128>,
|
||||
/// The subtile
|
||||
typename SubTile_ = Shape<1, 1, 64>,
|
||||
/// Work load per thread, per subtile
|
||||
typename ThreadShape_ = Shape<1, 1, 2>,
|
||||
/// The index
|
||||
typename Index_ = int,
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typename BlockSwizzle_ = DefaultBlockSwizzle,
|
||||
/// The input register vector size in kernel
|
||||
int maxInReg_ = 160,
|
||||
/// The output register vector size in kernel
|
||||
int maxOutReg_ = 64,
|
||||
/// The functor that will be executed at the end
|
||||
typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >
|
||||
>
|
||||
struct BatchedReductionTraits {
|
||||
///
|
||||
typedef BatchedReductionTraits<ScalarA_,
|
||||
ScalarC_,
|
||||
ScalarD_,
|
||||
ScalarAlphaBeta_,
|
||||
ScalarAccum_,
|
||||
ReductionSize_,
|
||||
OutputTile_,
|
||||
SubTile_,
|
||||
ThreadShape_,
|
||||
Index_,
|
||||
BlockSwizzle_,
|
||||
maxInReg_,
|
||||
maxOutReg_,
|
||||
Functor_> This_;
|
||||
/// The struct that consumes this Traits
|
||||
typedef typename cutlass::reduction::BatchedReduction<This_> KernelClass;
|
||||
///
|
||||
typedef OutputTile_ OutputTile;
|
||||
///
|
||||
typedef SubTile_ SubTile;
|
||||
///
|
||||
typedef ThreadShape_ ThreadShape;
|
||||
/// The input pointer type
|
||||
typedef ScalarA_ ScalarA;
|
||||
///
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The output pointer type
|
||||
typedef ScalarD_ ScalarD;
|
||||
/// The alpha beta type
|
||||
typedef ScalarAlphaBeta_ ScalarAlphaBeta;
|
||||
/// The type for accumulation
|
||||
typedef ScalarAccum_ ScalarAccum;
|
||||
/// The index
|
||||
typedef Index_ Index;
|
||||
/// The thread block swizzle
|
||||
typedef BlockSwizzle_ BlockSwizzle;
|
||||
///
|
||||
static const int ReductionSize = ReductionSize_;
|
||||
/// check if threadShape is multiple of 2.
|
||||
static const bool ThreadShapeMultiple2 = (ThreadShape::kW % 2 == 0);
|
||||
///
|
||||
typedef Functor_ Functor;
|
||||
/// Parameteres object constructable on the host
|
||||
/// The number of threads per thread block. can be deduced
|
||||
static int const kThreads = SubTile::kW / ThreadShape::kW;
|
||||
//
|
||||
static int const maxInReg = maxInReg_;
|
||||
//
|
||||
static int const maxOutReg = maxOutReg_;
|
||||
//
|
||||
static_assert(SubTile::kW % ThreadShape::kW == 0, "cannot evenly distribute work load among threads");
|
||||
//
|
||||
static_assert(kThreads % 32 == 0, "threads per threadblock is not multiple of 32");
|
||||
//
|
||||
static_assert(OutputTile::kW % SubTile::kW == 0, "cannot evenly distribute work load among iterations");
|
||||
//
|
||||
static_assert(ReductionSize * ThreadShape::kW <= maxInReg, "ReductionSize * ThreadShape::kW should not be bigger than maxInReg");
|
||||
//
|
||||
static_assert(ThreadShape::kW <= maxOutReg, "ThreadShape::kW should not be bigger than maxOutReg");
|
||||
|
||||
struct Params {
|
||||
/// The dimension of output tensor
|
||||
Coord<3> problem_size;
|
||||
/// The alpha
|
||||
ScalarAlphaBeta alpha;
|
||||
/// The beta
|
||||
ScalarAlphaBeta beta;
|
||||
/// stride between two element that will be sumed
|
||||
long long int reduction_stride;
|
||||
//
|
||||
ScalarA const *d_a;
|
||||
//
|
||||
Index lda;
|
||||
//
|
||||
ScalarC const *d_c;
|
||||
//
|
||||
Index ldc;
|
||||
//
|
||||
ScalarD *d_d;
|
||||
//
|
||||
Index ldd;
|
||||
/// The functor params.
|
||||
typename Functor::Params functorParams;
|
||||
/// Initialize the parameters for 2D output tensor
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m_,
|
||||
Index n_,
|
||||
ScalarAlphaBeta alpha_,
|
||||
ScalarAlphaBeta beta_,
|
||||
long long int reduction_stride_,
|
||||
ScalarA const *d_a_,
|
||||
Index lda_,
|
||||
ScalarC const *d_c_,
|
||||
Index ldc_,
|
||||
ScalarD *d_d_,
|
||||
Index ldd_){
|
||||
problem_size = make_Coord(1, n_, m_);
|
||||
alpha = alpha_;
|
||||
beta = beta_;
|
||||
reduction_stride = reduction_stride_;
|
||||
d_a = d_a_;
|
||||
lda = lda_;
|
||||
d_c = d_c_;
|
||||
d_d = d_d_;
|
||||
ldc = ldc_;
|
||||
ldd = ldd_;
|
||||
|
||||
functorParams.initialize(alpha_, beta_);
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
};
|
||||
} // namespace reduction
|
||||
} // namespace cutlass
|
||||
61
cutlass/reduction/threadblock_swizzle.h
Normal file
61
cutlass/reduction/threadblock_swizzle.h
Normal file
@ -0,0 +1,61 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation.
|
||||
*/
|
||||
#pragma once
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reduction {
|
||||
struct DefaultBlockSwizzle {
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
assert(OutputTile[0] == 1 && OutputTile[1] == 1);
|
||||
assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0);
|
||||
dim3 grid;
|
||||
grid.x = problem_size[0] * problem_size[1] * problem_size[2]
|
||||
/ OutputTile[2] ;
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) {
|
||||
assert(SubTile[0] == 1 && SubTile[1] == 1);
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, 0, block.x * SubTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
};
|
||||
} // namespace reduction
|
||||
} // namespace cutlass
|
||||
@ -53,6 +53,22 @@ struct ReshapeTile<Tile_, kAccessSize_, true> {
|
||||
typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / kAccessSize_, kAccessSize_> Tile;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename Tile_, int kAccessSize_, int kLdsPerAccess_, bool = (Tile_::kC < (kAccessSize_ * kLdsPerAccess_))>
|
||||
struct WmmaReshapeTile {
|
||||
typedef Tile_ Tile;
|
||||
};
|
||||
|
||||
template <typename Tile_, int kAccessSize_, int kLdsPerAccess_>
|
||||
struct WmmaReshapeTile<Tile_, kAccessSize_, kLdsPerAccess_, true> {
|
||||
// Make sure the W dimension of the tile is large enough.
|
||||
static_assert(Tile_::kW >= (kAccessSize_ * kLdsPerAccess_), "The W dimension is too small");
|
||||
// Make sure the dimension can be divided by the number of scalars.
|
||||
static_assert(Tile_::kW % (kAccessSize_ * kLdsPerAccess_) == 0, "Not supported");
|
||||
// Collapse the W dimension.
|
||||
typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / (kAccessSize_ * kLdsPerAccess_), (kAccessSize_ * kLdsPerAccess_)> Tile;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Introduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray.
|
||||
\brief Introduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@ -85,7 +85,7 @@ template <
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex_ = long long
|
||||
>
|
||||
struct TensorRefBatchStrided:
|
||||
struct TensorRefBatchStrided:
|
||||
public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
|
||||
|
||||
//
|
||||
@ -98,12 +98,16 @@ struct TensorRefBatchStrided:
|
||||
/// Storage type
|
||||
typedef typename Base::Storage Storage;
|
||||
|
||||
/// Rank of the logical tensor
|
||||
static int const kRank = Rank_;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Typically, strides in memory can be very large
|
||||
typedef LongIndex_ LongIndex;
|
||||
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef Coord<kRank> TensorCoord;
|
||||
|
||||
@ -121,7 +125,7 @@ struct TensorRefBatchStrided:
|
||||
/// Reference to the parent TensorBatchRef object
|
||||
TensorRefBatchStrided const &ref_;
|
||||
|
||||
/// Offset from the base TensorRef pointer
|
||||
/// Offset from the base TensorRef pointer
|
||||
LongIndex offset_;
|
||||
|
||||
public:
|
||||
@ -129,12 +133,12 @@ struct TensorRefBatchStrided:
|
||||
/// Constructs a ConstIterator from a parent TensorRefBatchStrided
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator(
|
||||
TensorRefBatchStrided const &ref,
|
||||
TensorRefBatchStrided const &ref,
|
||||
LongIndex offset = 0): ref_(ref), offset_(offset) { }
|
||||
|
||||
/// Obtains a TensorRef pointed to by the iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef *operator() const {
|
||||
TensorRef operator*() const {
|
||||
TensorRef ref(ref_);
|
||||
ref.add_pointer_offset(offset_);
|
||||
return ref;
|
||||
@ -158,7 +162,7 @@ struct TensorRefBatchStrided:
|
||||
/// Returns an iterator advanced by (idx) amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator+(Index idx) {
|
||||
return ConstIterator(ref, offset_ + ref_.tensor_stride * idx);
|
||||
return ConstIterator(ref_, offset_ + ref_.tensor_stride * idx);
|
||||
}
|
||||
|
||||
/// Advances this iterator by (idx) and returns a reference to self
|
||||
@ -198,7 +202,7 @@ struct TensorRefBatchStrided:
|
||||
|
||||
/// Returns the difference in offset between two iterators
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride operator-(ConstIterator const &it) {
|
||||
LongIndex operator-(ConstIterator const &it) {
|
||||
return offset_ - it.offset_;
|
||||
}
|
||||
};
|
||||
@ -218,10 +222,10 @@ struct TensorRefBatchStrided:
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefBatchStrided(): tensor_stride(0) { }
|
||||
|
||||
// Constructs form a tensor reference and
|
||||
// Constructs form a tensor reference and
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0):
|
||||
TensorRef(ref),
|
||||
TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0):
|
||||
TensorRef(ref),
|
||||
tensor_stride(_tensor_stride) { }
|
||||
|
||||
/// Gets the pointer offset
|
||||
@ -232,7 +236,7 @@ struct TensorRefBatchStrided:
|
||||
|
||||
// Returns a reference
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef at(Index idx) const {
|
||||
TensorRef at(Index idx = 0) const {
|
||||
TensorRef ref(*this);
|
||||
ref.add_pointer_offset(get_pointer_offset(idx));
|
||||
return ref;
|
||||
@ -245,6 +249,30 @@ struct TensorRefBatchStrided:
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper to construct a TensorRefBatchStrided<> object using type deduction
|
||||
template <typename TensorRef_>
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefBatchStrided<
|
||||
typename TensorRef_::Storage,
|
||||
TensorRef_::kRank,
|
||||
typename TensorRef_::MapFunc,
|
||||
TensorRef_::kStorageGrank,
|
||||
typename TensorRef_::Index,
|
||||
typename TensorRef_::LongIndex
|
||||
> make_TensorRefBatchStrided(
|
||||
TensorRef_ const &ref,
|
||||
typename TensorRef_::LongIndex batch_stride = 0) {
|
||||
|
||||
return TensorRefBatchStrided<
|
||||
typename TensorRef_::Storage,
|
||||
TensorRef_::kRank,
|
||||
typename TensorRef_::MapFunc,
|
||||
TensorRef_::kStorageGrank,
|
||||
typename TensorRef_::Index,
|
||||
typename TensorRef_::LongIndex
|
||||
>(ref, batch_stride);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// This satisfies TensorRefCollection and stores a collection of TensorRef objects. This is a
|
||||
@ -253,7 +281,7 @@ struct TensorRefBatchStrided:
|
||||
/// Note, TensorRef maps a logical coordinate space to an n-D array with rank kStorageRank. It
|
||||
/// maintains a stride vector of similar rank, but the least significant rank is defined to be 1.
|
||||
///
|
||||
/// The least significant stride of 1 is not stored, and therefore the number of stride arrays is
|
||||
/// The least significant stride of 1 is not stored, and therefore the number of stride arrays is
|
||||
/// kStorageRank - 1.
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
@ -274,9 +302,6 @@ struct TensorRefArray {
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// TensorRef type obtained from the TensorRefArray
|
||||
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> TensorRef;
|
||||
|
||||
/// Element pointed to by the TensorRef
|
||||
typedef Storage_ Storage;
|
||||
|
||||
@ -287,16 +312,17 @@ struct TensorRefArray {
|
||||
typedef LongIndex_ LongIndex;
|
||||
|
||||
/// Rank of the stride vector
|
||||
static int const kStorageRank = TensorRef::kStorageRank;
|
||||
static int const kStorageRank = StorageRank_;
|
||||
|
||||
/// TensorRefIterator over TensorRef objects in TensorRefArray
|
||||
/// TensorRefIterator over TensorRef objects in TensorRefArray
|
||||
class ConstIterator {
|
||||
public:
|
||||
|
||||
/// TensorRef returned by the iterator
|
||||
typedef Base TensorRef;
|
||||
/// Containing class's tensor rev
|
||||
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> TensorRef;
|
||||
|
||||
private:
|
||||
|
||||
/// Reference to the TensorRefArray
|
||||
TensorRefArray const &ref_;
|
||||
|
||||
@ -307,11 +333,11 @@ struct TensorRefArray {
|
||||
|
||||
/// Constructs a ConstIterator over the TensorRef objects
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator(TensorArrayRef const &ref, int idx = 0): ref_(ref), idx_(idx) { }
|
||||
ConstIterator(TensorRefArray const &ref, int idx = 0): ref_(ref), idx_(idx) { }
|
||||
|
||||
/// Obtains a TensorRef pointed to by this iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef *operator() const {
|
||||
TensorRef operator*() const {
|
||||
return ref_.reference(idx_);
|
||||
}
|
||||
|
||||
@ -367,6 +393,9 @@ struct TensorRefArray {
|
||||
}
|
||||
};
|
||||
|
||||
/// TensorRef type obtained from the TensorRefArray
|
||||
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> TensorRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
@ -383,13 +412,13 @@ struct TensorRefArray {
|
||||
|
||||
// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorArrayRef() { }
|
||||
TensorRefArray() { }
|
||||
|
||||
// Construct from pointers to arrays to strides
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorArrayRef(
|
||||
TensorRefArray(
|
||||
Storage **_pointers,
|
||||
Index _strides[kStorageRank - 1]): pointers(_pointers) {
|
||||
Index _strides[kStorageRank - 1]): pointers(_pointers) {
|
||||
|
||||
// Copy pointers to strides arrays
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
@ -399,11 +428,11 @@ struct TensorRefArray {
|
||||
|
||||
// Returns a TensorRef at the given index in the collection
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef at(Index idx) const {
|
||||
TensorRef at(Index idx = 0) const {
|
||||
Coord<kStorageRank - 1, Index> stride;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride[i] = stride_[idx][i];
|
||||
stride[i] = strides[idx][i];
|
||||
}
|
||||
return TensorRef(pointers[idx], stride);
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
|
||||
namespace cutlass {
|
||||
@ -61,6 +62,12 @@ struct TileAllocation {
|
||||
/// Defines the tensor reference for this allocation
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// View of memory
|
||||
typedef TensorView<Scalar const, 4> ConstTensorView;
|
||||
|
||||
/// View of memory
|
||||
typedef TensorView<Scalar, 4> TensorView;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
@ -91,6 +98,24 @@ struct TileAllocation {
|
||||
ConstTensorRef reference() const {
|
||||
return ConstTensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC));
|
||||
}
|
||||
|
||||
/// Returns a TensorView object pointing to the data
|
||||
CUTLASS_DEVICE
|
||||
TensorView view() {
|
||||
return TensorView(
|
||||
data(),
|
||||
make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC),
|
||||
make_Coord(Shape::kD, Shape::kH, Shape::kW, Shape::kC));
|
||||
}
|
||||
|
||||
/// Returns a TensorView object pointing to the data
|
||||
CUTLASS_DEVICE
|
||||
ConstTensorView view() const {
|
||||
return TensorView(
|
||||
data(),
|
||||
make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC),
|
||||
make_Coord(Shape::kD, Shape::kH, Shape::kW, Shape::kC));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -163,6 +163,9 @@ struct TileIteratorBase {
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Long index
|
||||
typedef long long LongIndex;
|
||||
|
||||
/// Skew quantity
|
||||
typedef Skew_ Skew;
|
||||
|
||||
@ -216,15 +219,15 @@ struct TileIteratorBase {
|
||||
// Dat members
|
||||
//
|
||||
|
||||
long long stride_d;
|
||||
Index stride_d;
|
||||
Index stride_h;
|
||||
Index stride_w;
|
||||
|
||||
long long inc_d;
|
||||
Index inc_d;
|
||||
Index inc_h;
|
||||
Index inc_w;
|
||||
|
||||
long long inc_advance;
|
||||
Index inc_advance;
|
||||
|
||||
//
|
||||
// Methods
|
||||
@ -236,13 +239,13 @@ struct TileIteratorBase {
|
||||
|
||||
/// Constructs params
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(long long _stride_d,
|
||||
Params(Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
long long _inc_advance)
|
||||
Index _inc_advance)
|
||||
: stride_d(_stride_d),
|
||||
stride_h(_stride_h),
|
||||
stride_w(_stride_w),
|
||||
@ -259,13 +262,13 @@ struct TileIteratorBase {
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(long long _stride_d,
|
||||
int initialize(Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
long long _inc_advance) {
|
||||
Index _inc_advance) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
@ -286,14 +289,14 @@ struct TileIteratorBase {
|
||||
|
||||
/// Initializes the parameters object from a vector of strides
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(long long _stride_d, Index _stride_h, Index _stride_w) {
|
||||
int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
|
||||
inc_w = stride_w * Delta::kW;
|
||||
inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
inc_d = stride_d * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
|
||||
inc_d = stride_h * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
|
||||
stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
|
||||
inc_advance = 0;
|
||||
@ -310,7 +313,7 @@ struct TileIteratorBase {
|
||||
inc_advance = Tile::kD * stride_d;
|
||||
}
|
||||
|
||||
inc_advance -= stride_d * Delta::kD * (Iterations::kD - 1) +
|
||||
inc_advance -= stride_h * Delta::kD * (Iterations::kD - 1) +
|
||||
stride_h * Delta::kH * (Iterations::kH - 1) +
|
||||
stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
|
||||
@ -436,6 +439,9 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::LongIndex LongIndex;
|
||||
|
||||
/// Skew quantity
|
||||
typedef typename Base::Skew Skew;
|
||||
|
||||
@ -513,10 +519,10 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar const *ptr,
|
||||
long long _stride_d,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance)
|
||||
@ -527,7 +533,7 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
|
||||
Params(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
|
||||
: pointer(ptr) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
}
|
||||
@ -557,7 +563,7 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w) {
|
||||
int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
@ -566,10 +572,10 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr,
|
||||
long long _stride_d,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
@ -720,7 +726,7 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
}
|
||||
|
||||
/// Adds a raw offset to the pointer
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }
|
||||
|
||||
CUTLASS_HOST_DEVICE Index stride_advance(void) {
|
||||
Index stride = params.stride_h;
|
||||
@ -734,7 +740,6 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
@ -876,6 +881,9 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Long index type
|
||||
typedef typename Base::LongIndex LongIndex;
|
||||
|
||||
/// Skew quantity
|
||||
typedef typename Base::Skew Skew;
|
||||
|
||||
@ -953,10 +961,10 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar *ptr,
|
||||
long long _stride_d,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
@ -979,7 +987,7 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w) {
|
||||
int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
@ -988,10 +996,10 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr,
|
||||
long long _stride_d,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
@ -1121,7 +1129,7 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
}
|
||||
|
||||
/// Adds a raw offset to the pointer
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }
|
||||
|
||||
/// Stores a single fragment element into memory.
|
||||
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c) {
|
||||
|
||||
124
cutlass/util/pair.h
Normal file
124
cutlass/util/pair.h
Normal file
@ -0,0 +1,124 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Defines a pair<>
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
namespace platform {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constructs an iterator from a pair of iterators
|
||||
template <typename T1, typename T2>
|
||||
struct Pair {
|
||||
|
||||
typedef T1 first_type;
|
||||
typedef T2 second_type;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
T1 first;
|
||||
T1 second;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pair() { }
|
||||
|
||||
/// Constructs a pair
|
||||
CUTLASS_HOST_DEVICE
|
||||
Pair(T1 const &first_, T2 const &second_): first(first_), second(second_) { }
|
||||
};
|
||||
|
||||
/// Constructs a pair and deduces types
|
||||
template <typename T1, typename T2>
|
||||
Pair<T1, T2> make_Pair(T1 const &first, T2 const &second) {
|
||||
return Pair<T1, T2>(first, second);
|
||||
}
|
||||
|
||||
/// Equality
|
||||
template <typename T1, typename T2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
|
||||
return (lhs.first == rhs.first) && (lhs.second == rhs.second);
|
||||
}
|
||||
|
||||
/// Inequality
|
||||
template <typename T1, typename T2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
/// Lexical comparison
|
||||
template <typename T1, typename T2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
|
||||
if (lhs.first < rhs.first) {
|
||||
return true;
|
||||
}
|
||||
else if (rhs.first < lhs.first) {
|
||||
return false;
|
||||
}
|
||||
else if (rhs.second < rhs.second) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Lexical comparison
|
||||
template <typename T1, typename T2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<=(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
|
||||
return !(rhs < lhs);
|
||||
}
|
||||
|
||||
/// Lexical comparison
|
||||
template <typename T1, typename T2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator>(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
|
||||
return (rhs < lhs);
|
||||
}
|
||||
|
||||
/// Lexical comparison
|
||||
template <typename T1, typename T2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator>=(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
|
||||
return !(lhs < rhs);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace platform
|
||||
} // namespace cutlass
|
||||
40
cutlass/util/performance_tuning.h
Normal file
40
cutlass/util/performance_tuning.h
Normal file
@ -0,0 +1,40 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are not permitted.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
#ifndef CUTLASS_PERFORMANCE_TUNING_H
|
||||
#define CUTLASS_PERFORMANCE_TUNING_H
|
||||
|
||||
// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if defined(_MSC_VER)
|
||||
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
||||
#endif
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL
|
||||
#endif
|
||||
|
||||
#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
|
||||
#endif // CUTLASS_PERFORMANCE_TUNING_H
|
||||
@ -32,6 +32,7 @@
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
#include "cutlass/zip_fragment.h"
|
||||
#include "cutlass/util/pair.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -72,7 +73,10 @@ class ZipTileIterator {
|
||||
typedef typename First::PredicateVector PredicateVector;
|
||||
|
||||
/// Index type
|
||||
typedef typename First::Index Index;
|
||||
typedef platform::Pair<typename First::Index, typename Second::Index> Index;
|
||||
|
||||
/// Long index type
|
||||
typedef platform::Pair<typename First::LongIndex, typename Second::LongIndex> LongIndex;
|
||||
|
||||
/// Tensor reference
|
||||
typedef ZipTensorRef<
|
||||
@ -276,9 +280,9 @@ class ZipTileIterator {
|
||||
CUTLASS_DEVICE ZipTileIterator &operator-=(int count) { return decrement(count); }
|
||||
|
||||
/// Adds an offset to both iterators
|
||||
CUTLASS_DEVICE void add_pointer_offset(Index offset) {
|
||||
first.add_pointer_offset(offset);
|
||||
second.add_pointer_offset(offset);
|
||||
CUTLASS_DEVICE void add_pointer_offset(LongIndex offset) {
|
||||
first.add_pointer_offset(offset.first);
|
||||
second.add_pointer_offset(offset.second);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -103,6 +103,7 @@
|
||||
// Defines cutlass::reference::host::Gemm()
|
||||
#include "tools/util/reference/host/gemm.h"
|
||||
|
||||
#pragma warning( disable : 4503)
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
|
||||
@ -144,18 +145,18 @@ cudaError_t Cutlass_FP16_SgemmNN(
|
||||
typename Gemm::Params params;
|
||||
|
||||
int result = params.initialize(
|
||||
M, // GEMM M dimension
|
||||
N, // GEMM N dimension
|
||||
K, // GEMM K dimension
|
||||
reinterpret_cast<half const &>(alpha), // scalar alpha - This is a legal conversion from cutlass::half_t to CUDA's half.
|
||||
A, // matrix A operand
|
||||
M, // GEMM M dimension
|
||||
N, // GEMM N dimension
|
||||
K, // GEMM K dimension
|
||||
reinterpret_cast<half const &>(alpha), // scalar alpha
|
||||
A, // matrix A operand
|
||||
lda,
|
||||
B, // matrix B operand
|
||||
B, // matrix B operand
|
||||
ldb,
|
||||
reinterpret_cast<half const &>(beta), // scalar beta - This is a legal conversion from cutlass::half_t to CUDA's half.
|
||||
C, // source matrix C
|
||||
reinterpret_cast<half const &>(beta), // scalar beta
|
||||
C, // source matrix C
|
||||
ldc,
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
ldc
|
||||
);
|
||||
|
||||
|
||||
38
examples/06_splitK_gemm/CMakeLists.txt
Normal file
38
examples/06_splitK_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_SPLITK_GEMM_SOURCES
|
||||
splitK_gemm.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
06_splitK_gemm
|
||||
${EXAMPLES_SPLITK_GEMM_SOURCES}
|
||||
)
|
||||
302
examples/06_splitK_gemm/splitK_gemm.cu
Normal file
302
examples/06_splitK_gemm/splitK_gemm.cu
Normal file
@ -0,0 +1,302 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device_gemm.h"
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "cutlass/gemm/device_gemm_traits.h"
|
||||
#pragma warning( disable : 4503)
|
||||
/*
|
||||
This example demonstrates how to use cutlass to compute sgemm with splitK
|
||||
splitK is useful for gemm with small M and N and reasonably large K.
|
||||
Because the sizes of M and N are small, the number of threadblocks we can launch is often limited and
|
||||
results in under utilization of the hardware.
|
||||
splitK allows us to divide a gemm across K dimension by first launching a partitionedK gemm (very similar to batched gemm),
|
||||
storing the intermediate result in workspace and then launching a second reduction kernel.
|
||||
Thus, as demonstrated by function cutlass_splitK_sgemm_nn(), the users need to create two traits, one for the partitionedK gemm,
|
||||
and one for the reduction. The users are also responsible for allocating and releasing the workspace memory. The size of the workspace
|
||||
memory can be queried by calling required_workspace_memory_in_byte().
|
||||
*/
|
||||
|
||||
template<int splits_count>
|
||||
cudaError_t cutlass_splitK_sgemm_nn(float const *A,
|
||||
int lda,
|
||||
float const *B,
|
||||
int ldb,
|
||||
float *C,
|
||||
int ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
int m,
|
||||
int n,
|
||||
int k) {
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
// create cutlass gemm traits for the first kernel
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor, /*the layout of A*/
|
||||
cutlass::MatrixLayout::kColumnMajor, /*the layout of B*/
|
||||
cutlass::Shape<8, 128, 128> > /*the tile for each threadblock*/
|
||||
SgemmTraits;
|
||||
|
||||
// create cutlass batched reduction traits for the second kernel
|
||||
// for reduction D = alpha * Reduction(A) + beta * C
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*the scalar type of A in reduction, not to be confused with A in GEMM*/
|
||||
float, /*the scalar type of C in reduction, not to be confused with C in GEMM*/
|
||||
float, /*the scalar type of D in reduction, not to be confused with D in GEMM*/
|
||||
float, /*the scalar type of alpha and beta in reduction*/
|
||||
float, /*the scalar type of accumulation in reduction*/
|
||||
splits_count /*reduction workload*/
|
||||
>
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm that packages gemm traits and batched reduction traits
|
||||
typedef cutlass::gemm::SplitkPIGemmTraits<SgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
|
||||
// kernel class
|
||||
typedef typename deviceGemmTraits::KernelClass deviceGemm;
|
||||
|
||||
// Params ctor requires M, N, K sizes
|
||||
typename deviceGemm::Params deviceGemmParams(m, n, k);
|
||||
|
||||
// query if workspace is needed. the workspace size is sizeof(accumulateType) * M * N * splits_count
|
||||
int workspace_size = deviceGemmParams.required_workspace_memory_in_byte();
|
||||
if (workspace_size <= 0) {
|
||||
std::cerr << "splitK workspace_size is smaller than 0" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
// allocate workspace memory
|
||||
float *workspace_ptr;
|
||||
result = cudaMalloc(&workspace_ptr, workspace_size);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// finish init Params
|
||||
deviceGemmParams.initialize(alpha, /*alpha*/
|
||||
A, /*A*/
|
||||
lda, /*lda*/
|
||||
B, /*B*/
|
||||
ldb, /*ldb*/
|
||||
beta, /*beta*/
|
||||
C, /*C*/
|
||||
ldc, /*ldc*/
|
||||
C, /*D, can point to the same memory with C*/
|
||||
ldc, /*ldc*/
|
||||
workspace_ptr /*ptr to workspace*/
|
||||
);
|
||||
|
||||
// launch the kernel
|
||||
deviceGemm::launch(deviceGemmParams);
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "launch result = " << result << std::endl;
|
||||
cudaFree(workspace_ptr);
|
||||
return result;
|
||||
}
|
||||
|
||||
// release the workspace memory
|
||||
result = cudaFree(workspace_ptr);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
}
|
||||
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
cudaError_t sgemm_nn_reference(std::vector<T> const &A,
|
||||
int lda,
|
||||
std::vector<T> const &B,
|
||||
int ldb,
|
||||
std::vector<T> &C,
|
||||
int ldc,
|
||||
T alpha,
|
||||
T beta,
|
||||
int m,
|
||||
int n,
|
||||
int k) {
|
||||
/*
|
||||
sgemm
|
||||
*/
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
for (int n_idx = 0; n_idx < n; n_idx++) {
|
||||
for (int m_idx = 0; m_idx < m; m_idx++) {
|
||||
T accum = beta * C[n_idx * ldc + m_idx];
|
||||
for (int k_idx = 0; k_idx < k; k_idx++) {
|
||||
accum += alpha
|
||||
* A[k_idx * lda + m_idx]
|
||||
* B[n_idx * ldb + k_idx];
|
||||
}
|
||||
C[n_idx * ldc + m_idx] = accum;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
int const m = 128;
|
||||
int const n = 128;
|
||||
int const k = 4096;
|
||||
//splits_count should be known at compile time
|
||||
int const splits_count = 80;
|
||||
|
||||
// A, B are non-transpose, column major
|
||||
int const lda = m;
|
||||
int const ldb = k;
|
||||
int const ldc = m;
|
||||
|
||||
int const count_A = lda * k;
|
||||
int const count_B = ldb * n;
|
||||
int const count_C = ldc * n;
|
||||
|
||||
// alpha and beta
|
||||
float alpha = 1.0f;
|
||||
float beta = 2.0f;
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
// allocate the host memory
|
||||
std::vector<float> host_A(count_A);
|
||||
std::vector<float> host_B(count_B);
|
||||
std::vector<float> host_C(count_C);
|
||||
std::vector<float> result_C(count_C);
|
||||
|
||||
// allocate the device memory
|
||||
float *A;
|
||||
float *B;
|
||||
float *C;
|
||||
|
||||
result = cudaMalloc(&A, count_A * sizeof(float));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMalloc(&B, count_B * sizeof(float));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMalloc(&C, count_C * sizeof(float));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// fill A
|
||||
for (int col_idx = 0; col_idx < k; col_idx++) {
|
||||
for (int row_idx = 0; row_idx < m; row_idx++) {
|
||||
host_A[row_idx + col_idx * lda] = static_cast<float>((row_idx + col_idx) % 10);
|
||||
}
|
||||
}
|
||||
|
||||
// fill B
|
||||
for (int col_idx = 0; col_idx < n; col_idx++) {
|
||||
for (int row_idx = 0; row_idx < k; row_idx++) {
|
||||
host_B[row_idx + col_idx * ldb] = static_cast<float>((row_idx - col_idx) % 5);
|
||||
}
|
||||
}
|
||||
|
||||
// fill C
|
||||
for (int col_idx = 0; col_idx < n; col_idx++) {
|
||||
for (int row_idx = 0; row_idx < m; row_idx++) {
|
||||
host_C[row_idx + col_idx * ldc] = 1.f;
|
||||
}
|
||||
}
|
||||
|
||||
// ref memory
|
||||
std::vector<float> ref_A(host_A);
|
||||
std::vector<float> ref_B(host_B);
|
||||
std::vector<float> ref_C(host_C);
|
||||
// copy host memory to device
|
||||
result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// run cutlass
|
||||
result = cutlass_splitK_sgemm_nn<splits_count>(A, lda, B, ldb, C, ldc, alpha, beta, m, n, k);
|
||||
if (result != cudaSuccess)
|
||||
return result;
|
||||
|
||||
// copy device memory to host
|
||||
result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
//compare with reference code
|
||||
result = sgemm_nn_reference(ref_A, lda, ref_B, ldb, ref_C, ldc, alpha, beta, m, n, k);
|
||||
if (result != 0)
|
||||
return result;
|
||||
|
||||
if (ref_C != result_C) {
|
||||
std::cout << "CUTLASS splitK gemm does not run correctly" << std::endl;
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
// free memory
|
||||
result = cudaFree(A);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaFree(B);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaFree(C);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
}
|
||||
@ -26,3 +26,4 @@ add_subdirectory(02_cutlass_utilities)
|
||||
add_subdirectory(03_strided_batched_gemm)
|
||||
add_subdirectory(04_tile_iterator)
|
||||
add_subdirectory(05_wmma_gemm)
|
||||
add_subdirectory(06_splitK_gemm)
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 89 KiB After Width: | Height: | Size: 89 KiB |
@ -21,6 +21,8 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include_directories("external/googletest/googletest/include")
|
||||
|
||||
add_subdirectory(external/googletest/googletest)
|
||||
add_subdirectory(test)
|
||||
add_subdirectory(nvrtc)
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ set(CUTLASS_PERF_TEST_HEADERS
|
||||
performance_result.h
|
||||
gemm/cublas_dispatch.h
|
||||
gemm/cutlass_dispatch.h
|
||||
gemm/cutlass_dispatch_splitK_PI.h
|
||||
gemm/gemm_perf_testbed.h
|
||||
gemm/gemm_profiler.h
|
||||
)
|
||||
@ -36,9 +37,11 @@ set(CUTLASS_PERF_TEST_HEADERS
|
||||
set(CUTLASS_PERF_TEST_SOURCES
|
||||
cutlass_perf_test.cu
|
||||
gemm/sgemm.cu
|
||||
gemm/sgemm_splitK.cu
|
||||
gemm/dgemm.cu
|
||||
gemm/hgemm.cu
|
||||
gemm/igemm.cu
|
||||
gemm/igemm_splitK.cu
|
||||
gemm/wmma_gemm.cu
|
||||
gemm/wmma_binary_gemm.cu
|
||||
gemm/wmma_integer_gemm.cu
|
||||
|
||||
@ -1,121 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/// \file {nv-internal-release}
|
||||
|
||||
#if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750))
|
||||
#pragma warning( disable : 4503)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/bmma_gemm_traits.h"
|
||||
#include "tools/test/perf/cutlass_perf_test.h"
|
||||
#include "tools/test/perf/gemm/gemm_profiler.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch.h"
|
||||
#include "tools/test/perf/gemm/gemm_perf_testbed.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Traits>
|
||||
struct BmmaGemmDispatch {
|
||||
|
||||
typedef cutlass::gemm::Gemm<Traits> Gemm;
|
||||
|
||||
typedef typename Gemm::Params Params;
|
||||
|
||||
/// Indicate warp-level GEMM
|
||||
static bool const kThreadMultiplyAdd = false;
|
||||
|
||||
static bool const kRunCuBLAS = false;
|
||||
|
||||
static cutlass::MatrixLayout::Kind const kLayoutA = Traits::kLayoutA;
|
||||
static cutlass::MatrixLayout::Kind const kLayoutB = Traits::kLayoutB;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Params argument
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
BmmaGemmDispatch() {}
|
||||
|
||||
/// Initializes params object
|
||||
BmmaGemmDispatch(int m, int n, int k, int alpha,
|
||||
cutlass::Vector<cutlass::bin1_t, 32> const* d_a, int lda,
|
||||
cutlass::Vector<cutlass::bin1_t, 32> const* d_b, int ldb, int beta,
|
||||
int const* d_c, int ldc, int* d_d, int ldd) {
|
||||
|
||||
params.initialize(m, n, k * 32, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
BmmaGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
/// Launches kernel
|
||||
cudaError_t operator()() { return Gemm::launch(params); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace perf {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int profile_bmma_gemm(TestbenchOutput<GemmProblem> &output, TestbenchOptions const &options, Config const &config) {
|
||||
typedef perf::GemmProfiler<cutlass::Vector<cutlass::bin1_t, 32>, cutlass::Vector<cutlass::bin1_t, 32>, int, int, int> GemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
{
|
||||
|
||||
typedef cutlass::gemm::BmmaGemmTraits<cutlass::Shape<1024, 128, 128>,
|
||||
cutlass::Shape<1024, 32, 32>,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor>
|
||||
BmmaGemmTraits;
|
||||
|
||||
typedef BmmaGemmDispatch<BmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "bmma_gemm_tn", options, config);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct BmmaGemmRegistrar {
|
||||
BmmaGemmRegistrar() { RegisterGemmProfileFunc(profile_bmma_gemm); }
|
||||
};
|
||||
|
||||
volatile BmmaGemmRegistrar _BmmaGemmRegistrar;
|
||||
|
||||
} // namespace perf
|
||||
|
||||
#endif // if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
|
||||
@ -89,4 +89,76 @@ struct CublasGemmDispatch {
|
||||
}
|
||||
};
|
||||
|
||||
/// Dispatcher for batched strided cuBLAS kernels
|
||||
template <typename AType, typename BType, typename CType, typename Accumulator, typename Scalar>
|
||||
struct CublasBatchedStridedGemmDispatch {
|
||||
/// Type used for device-side allocations
|
||||
typedef typename cutlass::TypeTraits<AType>::device_type ADeviceType;
|
||||
typedef typename cutlass::TypeTraits<BType>::device_type BDeviceType;
|
||||
typedef typename cutlass::TypeTraits<CType>::device_type CDeviceType;
|
||||
typedef typename cutlass::TypeTraits<Accumulator>::device_type AccumulatorDeviceType;
|
||||
typedef typename cutlass::TypeTraits<Scalar>::device_type ScalarDeviceType;
|
||||
|
||||
static cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) {
|
||||
switch (layout) {
|
||||
case cutlass::MatrixLayout::kRowMajor:
|
||||
return CUBLAS_OP_T;
|
||||
case cutlass::MatrixLayout::kColumnMajor:
|
||||
return CUBLAS_OP_N;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return CUBLAS_OP_N;
|
||||
}
|
||||
|
||||
/// Launches a cuBLAS GEMM kernel
|
||||
cublasStatus_t operator()(cublasHandle_t handle,
|
||||
cutlass::MatrixLayout::Kind layout_a,
|
||||
cutlass::MatrixLayout::Kind layout_b,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
Scalar alpha,
|
||||
const ADeviceType *A,
|
||||
int lda,
|
||||
long long int batch_stride_A,
|
||||
const BDeviceType *B,
|
||||
int ldb,
|
||||
long long int batch_stride_B,
|
||||
Scalar beta,
|
||||
CDeviceType *C,
|
||||
int ldc,
|
||||
long long int batch_stride_C,
|
||||
int batch_count,
|
||||
cublasGemmAlgo_t algorithm) {
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 9010
|
||||
return cublasGemmStridedBatchedEx(handle,
|
||||
convert(layout_a),
|
||||
convert(layout_b),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
reinterpret_cast<ScalarDeviceType const *>(&alpha),
|
||||
A,
|
||||
cutlass::TypeTraits<ADeviceType>::cublas_type,
|
||||
lda,
|
||||
batch_stride_A,
|
||||
B,
|
||||
cutlass::TypeTraits<BDeviceType>::cublas_type,
|
||||
ldb,
|
||||
batch_stride_B,
|
||||
reinterpret_cast<ScalarDeviceType const *>(&beta),
|
||||
C,
|
||||
cutlass::TypeTraits<CDeviceType>::cublas_type,
|
||||
ldc,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
cutlass::TypeTraits<AccumulatorDeviceType>::cublas_type,
|
||||
algorithm);
|
||||
#else
|
||||
return CUBLAS_STATUS_NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace perf
|
||||
|
||||
@ -81,6 +81,32 @@ struct CutlassDispatch {
|
||||
params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
/// Initializes batched strided params object
|
||||
CutlassDispatch(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
ScalarEpilogue alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
long long int batch_stride_A,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
long long int batch_stride_B,
|
||||
ScalarEpilogue beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
long long int batch_stride_C,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
long long int batch_stride_D,
|
||||
Index batch_count) {
|
||||
params.initialize(m, n, k, alpha, d_a, lda, batch_stride_A,
|
||||
d_b, ldb, batch_stride_B,
|
||||
beta, d_c, ldc, batch_stride_C,
|
||||
d_d, ldd, batch_stride_D,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
CutlassDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
|
||||
172
tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h
Normal file
172
tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h
Normal file
@ -0,0 +1,172 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "tools/util/type_traits.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <assert.h>
|
||||
|
||||
namespace perf {
|
||||
|
||||
template <typename KernelClass_,
|
||||
typename Index_,
|
||||
typename ScalarA_,
|
||||
typename ScalarB_,
|
||||
typename ScalarC_,
|
||||
typename ScalarD_,
|
||||
typename Compute_,
|
||||
typename ScalarEpilogue_,
|
||||
bool ThreadMultiplyAdd_,
|
||||
bool RunCuBLAS_ = true>
|
||||
struct CutlassDispatchSplitKPIGemm {
|
||||
typedef typename KernelClass_::Params Params;
|
||||
typedef KernelClass_ KernelClass;
|
||||
typedef Index_ Index;
|
||||
typedef ScalarA_ ScalarA;
|
||||
typedef ScalarB_ ScalarB;
|
||||
typedef ScalarC_ ScalarC;
|
||||
typedef ScalarD_ ScalarD;
|
||||
typedef Compute_ Compute;
|
||||
typedef ScalarEpilogue_ ScalarEpilogue;
|
||||
|
||||
static bool const kThreadMultiplyAdd = ThreadMultiplyAdd_;
|
||||
static bool const kRunCuBLAS = RunCuBLAS_;
|
||||
|
||||
static cutlass::MatrixLayout::Kind const kLayoutA = KernelClass::Traits::kLayoutA;
|
||||
static cutlass::MatrixLayout::Kind const kLayoutB = KernelClass::Traits::kLayoutB;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Params argument
|
||||
Params params;
|
||||
|
||||
/// splitK PI require workspace
|
||||
typename cutlass::TypeTraits<Compute>::device_type *workspace_ptr;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor Initializes params object
|
||||
CutlassDispatchSplitKPIGemm(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
ScalarEpilogue alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
ScalarEpilogue beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
params.init_problem(m, n, k);
|
||||
int workspace_size_in_byte = params.required_workspace_memory_in_byte();
|
||||
|
||||
cudaError_t workspace_err = cudaMalloc(&workspace_ptr, workspace_size_in_byte);
|
||||
if (workspace_err != cudaSuccess) {
|
||||
std::cout << "\nCUDA workspace malloc error: " << cudaGetErrorString(workspace_err)
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
params.initialize(alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd, workspace_ptr);
|
||||
}
|
||||
|
||||
/// Initializes batched strided params object
|
||||
CutlassDispatchSplitKPIGemm(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
ScalarEpilogue alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
long long int batch_stride_A,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
long long int batch_stride_B,
|
||||
ScalarEpilogue beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
long long int batch_stride_C,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
long long int batch_stride_D,
|
||||
Index batch_count) {
|
||||
assert(0);//batched strided splitK should never be called
|
||||
}
|
||||
|
||||
/// Launches kernel
|
||||
cudaError_t operator()() { return KernelClass::launch(params); }
|
||||
|
||||
~CutlassDispatchSplitKPIGemm() {
|
||||
cudaError_t workspace_err = cudaFree(workspace_ptr);
|
||||
if (workspace_err != cudaSuccess) {
|
||||
std::cout << "\nCUDA workspace malloc error: " << cudaGetErrorString(workspace_err)
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<
|
||||
typename SplitKPIGemmTraits_
|
||||
>
|
||||
struct CutlassDispatchSplitKPIGemmBasic {
|
||||
///
|
||||
typedef SplitKPIGemmTraits_ Traits;
|
||||
|
||||
///
|
||||
typedef typename Traits::KernelClass KernelClass;
|
||||
|
||||
/// Index type
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// The scalar for A.
|
||||
typedef typename Traits::ScalarA ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef typename Traits::ScalarB ScalarB;
|
||||
/// The scalar for C.
|
||||
typedef typename Traits::ScalarC ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename Traits::ScalarD ScalarD;
|
||||
|
||||
// TODO - support alternative accumulator and scalar types
|
||||
typedef ScalarD Compute;
|
||||
typedef Compute ScalarEpilogue;
|
||||
|
||||
typedef CutlassDispatchSplitKPIGemm<KernelClass,
|
||||
Index,
|
||||
ScalarA,
|
||||
ScalarB,
|
||||
ScalarC,
|
||||
ScalarD,
|
||||
Compute,
|
||||
ScalarEpilogue,
|
||||
true>
|
||||
Dispatch;
|
||||
};
|
||||
} //namespace perf
|
||||
@ -78,6 +78,7 @@ class GemmTestbed {
|
||||
|
||||
/// Dispatch object to cuBLAS GEMM
|
||||
typedef CublasGemmDispatch<AType, BType, CType, Accumulator, Scalar> CublasDispatch;
|
||||
typedef CublasBatchedStridedGemmDispatch<AType, BType, CType, Accumulator, Scalar> CublasBatchedStridedGemmDispatch;
|
||||
|
||||
//
|
||||
// Type definitions
|
||||
@ -160,18 +161,20 @@ class GemmTestbed {
|
||||
|
||||
/// Resizes each tensor
|
||||
void resize_helper(GemmProblem const &problem) {
|
||||
resize_device_allocation(A,
|
||||
initial_distribution.dist_A,
|
||||
initial_distribution.seed,
|
||||
problem.m,
|
||||
problem.k,
|
||||
problem.layout_A);
|
||||
|
||||
resize_device_allocation(A,
|
||||
initial_distribution.dist_A,
|
||||
initial_distribution.seed,
|
||||
problem.m,
|
||||
problem.k * problem.batch_count,
|
||||
problem.layout_A);
|
||||
|
||||
|
||||
resize_device_allocation(
|
||||
B,
|
||||
initial_distribution.dist_B,
|
||||
initial_distribution.seed + 17, // compute distinct value from initial seed
|
||||
problem.k,
|
||||
problem.k * problem.batch_count,
|
||||
problem.n,
|
||||
problem.layout_B);
|
||||
|
||||
@ -180,21 +183,21 @@ class GemmTestbed {
|
||||
initial_distribution.dist_C,
|
||||
initial_distribution.seed + 101, // compute distinct value from initial seed
|
||||
problem.m,
|
||||
problem.n,
|
||||
problem.n * problem.batch_count,
|
||||
cutlass::MatrixLayout::kColumnMajor);
|
||||
|
||||
resize_device_allocation(reference,
|
||||
cutlass::Distribution(),
|
||||
0,
|
||||
problem.m,
|
||||
problem.n,
|
||||
problem.n * problem.batch_count,
|
||||
cutlass::MatrixLayout::kColumnMajor);
|
||||
|
||||
resize_device_allocation(experimental,
|
||||
cutlass::Distribution(),
|
||||
0,
|
||||
problem.m,
|
||||
problem.n,
|
||||
problem.n * problem.batch_count,
|
||||
cutlass::MatrixLayout::kColumnMajor);
|
||||
}
|
||||
|
||||
@ -315,24 +318,36 @@ class GemmTestbed {
|
||||
/// Inner dimension of GEMM problem
|
||||
int K() const { return problem.k; }
|
||||
|
||||
/// batch count
|
||||
int batch_count() const { return problem.batch_count; }
|
||||
|
||||
/// Returns a pointer to the A operand
|
||||
ADeviceType *ptr_A() const { return A.get(); }
|
||||
|
||||
/// Leading dimension of A
|
||||
int lda() const { return problem.lda(); }
|
||||
|
||||
///
|
||||
long long int batch_stride_a() const{ return problem.batch_stride_a(); }
|
||||
|
||||
/// Returns a pointer to the B operand
|
||||
BDeviceType *ptr_B() const { return B.get(); }
|
||||
|
||||
/// Leading dimension of B
|
||||
int ldb() const { return problem.ldb(); }
|
||||
|
||||
///
|
||||
long long int batch_stride_b() const{ return problem.batch_stride_b(); }
|
||||
|
||||
/// Returns a pointer to the initial state of the result tensor in device memory
|
||||
CDeviceType *ptr_C_initial() const { return C_initial.get(); }
|
||||
|
||||
/// Leading dimension of C
|
||||
int ldc() const { return problem.ldc(); }
|
||||
|
||||
///
|
||||
long long int batch_stride_c() const { return problem.batch_stride_c(); }
|
||||
|
||||
/// Returns a pointer to the result tensor in device memory
|
||||
CDeviceType *ptr_experimental() const { return experimental.get(); }
|
||||
|
||||
@ -341,7 +356,7 @@ class GemmTestbed {
|
||||
|
||||
/// Returns the number of flops implied by the computation (1 multiply-accumulate = 2 flops)
|
||||
uint64_t flops() const {
|
||||
return uint64_t(problem.m) * uint64_t(problem.n) * uint64_t(problem.k) * detail::ElementCount<AType>::kValue * 2ULL;
|
||||
return uint64_t(problem.batch_count) * uint64_t(problem.m) * uint64_t(problem.n) * uint64_t(problem.k) * detail::ElementCount<AType>::kValue * 2ULL;
|
||||
}
|
||||
|
||||
/// Computes the speed of the computation in GFLOPs/s
|
||||
@ -373,28 +388,59 @@ class GemmTestbed {
|
||||
|
||||
/// Launches the cuBLAS GEMM - does not initialize output matrix
|
||||
cublasStatus_t launch_cublas(cublasGemmAlgo_t algo) {
|
||||
CublasDispatch dispatch;
|
||||
if (problem.batch_count == 1) {
|
||||
CublasDispatch dispatch;
|
||||
|
||||
Scalar alpha(Scalar(problem.alpha));
|
||||
Scalar beta(Scalar(problem.beta));
|
||||
Scalar alpha(Scalar(problem.alpha));
|
||||
Scalar beta(Scalar(problem.beta));
|
||||
|
||||
status = dispatch(handle,
|
||||
problem.layout_A,
|
||||
problem.layout_B,
|
||||
problem.m,
|
||||
problem.n,
|
||||
problem.k,
|
||||
alpha,
|
||||
ptr_A(),
|
||||
lda(),
|
||||
ptr_B(),
|
||||
ldb(),
|
||||
beta,
|
||||
ptr_reference(),
|
||||
ldc(),
|
||||
algo);
|
||||
status = dispatch(handle,
|
||||
problem.layout_A,
|
||||
problem.layout_B,
|
||||
problem.m,
|
||||
problem.n,
|
||||
problem.k,
|
||||
alpha,
|
||||
ptr_A(),
|
||||
lda(),
|
||||
ptr_B(),
|
||||
ldb(),
|
||||
beta,
|
||||
ptr_reference(),
|
||||
ldc(),
|
||||
algo);
|
||||
|
||||
return status;
|
||||
return status;
|
||||
}
|
||||
else {
|
||||
// call batched strided cublas
|
||||
CublasBatchedStridedGemmDispatch dispatch;
|
||||
|
||||
Scalar alpha(Scalar(problem.alpha));
|
||||
Scalar beta(Scalar(problem.beta));
|
||||
|
||||
status = dispatch(handle,
|
||||
problem.layout_A,
|
||||
problem.layout_B,
|
||||
problem.m,
|
||||
problem.n,
|
||||
problem.k,
|
||||
alpha,
|
||||
ptr_A(),
|
||||
lda(),
|
||||
batch_stride_a(),
|
||||
ptr_B(),
|
||||
ldb(),
|
||||
batch_stride_b(),
|
||||
beta,
|
||||
ptr_reference(),
|
||||
ldc(),
|
||||
batch_stride_c(),
|
||||
batch_count(),
|
||||
algo);
|
||||
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies the 'test' tensor with 'ref'
|
||||
|
||||
@ -164,24 +164,52 @@ class GemmProfiler {
|
||||
result.disposition = Disposition::Passed;
|
||||
}
|
||||
|
||||
CutlassDispatch dispatch(testbed.M(),
|
||||
testbed.N(),
|
||||
testbed.K(),
|
||||
testbed.alpha(),
|
||||
testbed.ptr_A(),
|
||||
testbed.lda(),
|
||||
testbed.ptr_B(),
|
||||
testbed.ldb(),
|
||||
testbed.beta(),
|
||||
testbed.ptr_C_initial(),
|
||||
testbed.ldc(),
|
||||
testbed.ptr_experimental(),
|
||||
testbed.ldc());
|
||||
CutlassDispatch *dispatch_ptr;
|
||||
|
||||
dispatch();
|
||||
// check to see if we need to launch batched strided gemm
|
||||
if (testbed.batch_count() == 1) {
|
||||
dispatch_ptr = new CutlassDispatch(testbed.M(),
|
||||
testbed.N(),
|
||||
testbed.K(),
|
||||
testbed.alpha(),
|
||||
testbed.ptr_A(),
|
||||
testbed.lda(),
|
||||
testbed.ptr_B(),
|
||||
testbed.ldb(),
|
||||
testbed.beta(),
|
||||
testbed.ptr_C_initial(),
|
||||
testbed.ldc(),
|
||||
testbed.ptr_experimental(),
|
||||
testbed.ldc());
|
||||
|
||||
dispatch_ptr->operator()();
|
||||
}
|
||||
else {
|
||||
dispatch_ptr = new CutlassDispatch(testbed.M(),
|
||||
testbed.N(),
|
||||
testbed.K(),
|
||||
testbed.alpha(),
|
||||
testbed.ptr_A(),
|
||||
testbed.lda(),
|
||||
testbed.batch_stride_a(),
|
||||
testbed.ptr_B(),
|
||||
testbed.ldb(),
|
||||
testbed.batch_stride_b(),
|
||||
testbed.beta(),
|
||||
testbed.ptr_C_initial(),
|
||||
testbed.ldc(),
|
||||
testbed.batch_stride_c(),
|
||||
testbed.ptr_experimental(),
|
||||
testbed.ldc(),
|
||||
testbed.batch_stride_c(),
|
||||
testbed.batch_count());
|
||||
|
||||
dispatch_ptr->operator()();
|
||||
}
|
||||
|
||||
if (cudaDeviceSynchronize() != cudaSuccess) {
|
||||
result.disposition = Disposition::Failed;
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -202,35 +230,40 @@ class GemmProfiler {
|
||||
}
|
||||
|
||||
// warmup launch
|
||||
dispatch();
|
||||
dispatch_ptr->operator()();
|
||||
|
||||
if (cudaDeviceSynchronize() != cudaSuccess) {
|
||||
result.disposition = Disposition::Failed;
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (cudaEventRecord(events[0]) != cudaSuccess) {
|
||||
result.disposition = Disposition::Failed;
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
dispatch();
|
||||
dispatch_ptr->operator()();
|
||||
}
|
||||
|
||||
if (cudaEventRecord(events[1]) != cudaSuccess) {
|
||||
result.disposition = Disposition::Failed;
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (cudaEventSynchronize(events[1]) != cudaSuccess) {
|
||||
result.disposition = Disposition::Failed;
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
float average_ms = 0;
|
||||
if (cudaEventElapsedTime(&average_ms, events[0], events[1]) != cudaSuccess) {
|
||||
result.disposition = Disposition::Failed;
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -242,6 +275,7 @@ class GemmProfiler {
|
||||
<< " failed with disposition: " << result.disposition << "\n";
|
||||
}
|
||||
|
||||
delete dispatch_ptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -265,7 +299,7 @@ class GemmProfiler {
|
||||
|
||||
std::vector<PerformanceResult<GemmProblem> > results;
|
||||
|
||||
results.push_back(execute_cutlass<CutlassDispatch>(problem, algorithm));
|
||||
results.push_back(execute_cutlass<CutlassDispatch>(problem, algorithm));
|
||||
// cool-down period
|
||||
if (!options.dry_run) {
|
||||
pause(options.sleep_time);
|
||||
@ -276,28 +310,30 @@ class GemmProfiler {
|
||||
|
||||
/// Runs the test and collects performance for all results
|
||||
template <typename CutlassDispatch>
|
||||
void schmoo(Range const &M, Range const &N, Range const &K) {
|
||||
for (int m = M.start; m <= M.end; m = M.next(m)) {
|
||||
for (int n = N.start; n <= N.end; n = N.next(n)) {
|
||||
for (int k = K.start; k <= K.end; k = K.next(k)) {
|
||||
|
||||
std::vector<PerformanceResult<GemmProblem> > results =
|
||||
void schmoo(Range const &M, Range const &N, Range const &K, Range const &batch_count) {
|
||||
for (int b = batch_count.start; b <= batch_count.end; b = batch_count.next(b)) {
|
||||
for (int m = M.start; m <= M.end; m = M.next(m)) {
|
||||
for (int n = N.start; n <= N.end; n = N.next(n)) {
|
||||
for (int k = K.start; k <= K.end; k = K.next(k)) {
|
||||
std::vector<PerformanceResult<GemmProblem> > results =
|
||||
execute<CutlassDispatch>(GemmProblem(m,
|
||||
n,
|
||||
k,
|
||||
CutlassDispatch::kLayoutA,
|
||||
CutlassDispatch::kLayoutB,
|
||||
config.alpha,
|
||||
config.beta));
|
||||
n,
|
||||
k,
|
||||
CutlassDispatch::kLayoutA,
|
||||
CutlassDispatch::kLayoutB,
|
||||
config.alpha,
|
||||
config.beta,
|
||||
b));
|
||||
|
||||
for (std::vector<PerformanceResult<GemmProblem> >::const_iterator it = results.begin();
|
||||
it != results.end();
|
||||
++it) {
|
||||
output.append(*it);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (std::vector<PerformanceResult<GemmProblem> >::const_iterator it = results.begin();
|
||||
it != results.end();
|
||||
++it) {
|
||||
output.append(*it);
|
||||
}
|
||||
}//k
|
||||
}//n
|
||||
}//m
|
||||
}//batch_count
|
||||
}
|
||||
|
||||
/// Runs the test over the problem space and reports only the best performance
|
||||
@ -369,7 +405,7 @@ int profile_gemm(TestbenchOutput<GemmProblem> &output,
|
||||
config.problem_range.M, config.problem_range.N, config.problem_range.K);
|
||||
} else {
|
||||
perf.template schmoo<Dispatch>(
|
||||
config.problem_range.M, config.problem_range.N, config.problem_range.K);
|
||||
config.problem_range.M, config.problem_range.N, config.problem_range.K, config.problem_range.batch_count);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
202
tools/test/perf/gemm/igemm_splitK.cu
Normal file
202
tools/test/perf/gemm/igemm_splitK.cu
Normal file
@ -0,0 +1,202 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/igemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "cutlass/gemm/device_gemm_traits.h"
|
||||
#include "tools/test/perf/cutlass_perf_test.h"
|
||||
#include "tools/test/perf/gemm/gemm_perf_testbed.h"
|
||||
#include "tools/test/perf/gemm/gemm_profiler.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h"
|
||||
#pragma warning( disable : 4503)
|
||||
|
||||
namespace perf {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile, int splits_count>
|
||||
int profile_igemm_splitkpi_kernel(
|
||||
TestbenchOutput<GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config,
|
||||
std::string const &name,
|
||||
std::string const &algo) {
|
||||
|
||||
typedef perf::GemmProfiler<int8_t, int8_t, int, int, int> GemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
{
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
OutputTile
|
||||
> IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<IgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, name + "_nn", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
{
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
OutputTile
|
||||
> IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<IgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, name + "_nt", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
{
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
OutputTile
|
||||
> IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<IgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, name + "_tn", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
{
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
OutputTile
|
||||
> IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<IgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, name + "_tt", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/// Profiles all SGEMM tile sizes
|
||||
int profile_igemm_splitkpi(TestbenchOutput<GemmProblem> &output, TestbenchOptions const &options, Config const &config) {
|
||||
int results = 0;
|
||||
/*128x128x32*/
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 128, 128>, 8 >(output, options, config, "igemm_128x128x32_splitk_pi_split8", "128x128");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 128, 128>, 16 >(output, options, config, "igemm_128x128x32_splitk_pi_split16", "128x128");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 128, 128>, 32 >(output, options, config, "igemm_128x128x32_splitk_pi_split32", "128x128");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 128, 128>, 64 >(output, options, config, "igemm_128x128x32_splitk_pi_split64", "128x128");
|
||||
|
||||
/*128x64x32*/
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 128>, 8 >(output, options, config, "igemm_128x64x32_splitk_pi_split8", "128x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 128>, 16 >(output, options, config, "igemm_128x64x32_splitk_pi_split16", "128x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 128>, 20 >(output, options, config, "igemm_128x64x32_splitk_pi_split20", "128x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 128>, 32 >(output, options, config, "igemm_128x64x32_splitk_pi_split32", "128x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 128>, 64 >(output, options, config, "igemm_128x64x32_splitk_pi_split64", "128x64");
|
||||
|
||||
/*128x32x32*/
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 32, 128>, 8 >(output, options, config, "igemm_128x32x32_splitk_pi_split8", "128x32");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 32, 128>, 16 >(output, options, config, "igemm_128x32x32_splitk_pi_split16", "128x32");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 32, 128>, 20 >(output, options, config, "igemm_128x32x32_splitk_pi_split20", "128x32");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 32, 128>, 32 >(output, options, config, "igemm_128x32x32_splitk_pi_split32", "128x32");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 32, 128>, 64 >(output, options, config, "igemm_128x32x32_splitk_pi_split64", "128x32");
|
||||
|
||||
/*64x64x32*/
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 64>, 8 >(output, options, config, "igemm_64x64x32_splitk_pi_split8", "64x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 64>, 16 >(output, options, config, "igemm_64x64x32_splitk_pi_split16", "64x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 64>, 20 >(output, options, config, "igemm_64x64x32_splitk_pi_split20", "64x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 64>, 32 >(output, options, config, "igemm_64x64x32_splitk_pi_split32", "64x64");
|
||||
results |= profile_igemm_splitkpi_kernel<cutlass::Shape<32, 64, 64>, 64 >(output, options, config, "igemm_64x64x32_splitk_pi_split64", "64x64");
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
struct IgemmSplitKPIRegistrar {
|
||||
IgemmSplitKPIRegistrar() { RegisterGemmProfileFunc(profile_igemm_splitkpi); }
|
||||
};
|
||||
|
||||
volatile IgemmSplitKPIRegistrar _IgemmSplitKPIRegistrar;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace perf
|
||||
File diff suppressed because it is too large
Load Diff
187
tools/test/perf/gemm/sgemm_splitK.cu
Normal file
187
tools/test/perf/gemm/sgemm_splitK.cu
Normal file
@ -0,0 +1,187 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "cutlass/gemm/device_gemm_traits.h"
|
||||
#include "tools/test/perf/cutlass_perf_test.h"
|
||||
#include "tools/test/perf/gemm/gemm_perf_testbed.h"
|
||||
#include "tools/test/perf/gemm/gemm_profiler.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch_splitK_PI.h"
|
||||
#pragma warning( disable : 4503)
|
||||
|
||||
namespace perf {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile, int splits_count>
|
||||
int profile_sgemm_splitkpi_kernel(
|
||||
TestbenchOutput<GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config,
|
||||
std::string const &name,
|
||||
std::string const &algo) {
|
||||
|
||||
typedef perf::GemmProfiler<float, float, float, float, float> SGemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
{
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, OutputTile>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<SgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, SGemmProfiler>(output, name + "_nn", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
{
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, OutputTile>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<SgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, SGemmProfiler>(output, name + "_nt", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
{
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, OutputTile>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<SgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, SGemmProfiler>(output, name + "_tn", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
{
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, OutputTile>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
// create a device gemm
|
||||
typedef typename cutlass::gemm::SplitkPIGemmTraits<SgemmTraits, BatchedReductionTraits> deviceGemmTraits;
|
||||
typedef typename CutlassDispatchSplitKPIGemmBasic<deviceGemmTraits>::Dispatch Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, SGemmProfiler>(output, name + "_tt", options, config, algo + "_splitk_pi");
|
||||
}
|
||||
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/// Profiles all SGEMM tile sizes
|
||||
int profile_sgemm_splitkpi(TestbenchOutput<GemmProblem> &output, TestbenchOptions const &options, Config const &config) {
|
||||
int results = 0;
|
||||
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 128, 128>, 32 >(output, options, config, "sgemm_128x128x8_splitk_pi_split32", "128x128");
|
||||
|
||||
/*128x64x8*/
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 8 >(output, options, config, "sgemm_128x64x8_splitk_pi_split8", "128x64");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 16 >(output, options, config, "sgemm_128x64x8_splitk_pi_split16", "128x64");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 20 >(output, options, config, "sgemm_128x64x8_splitk_pi_split20", "128x64");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 24 >(output, options, config, "sgemm_128x64x8_splitk_pi_split24", "128x64");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 28 >(output, options, config, "sgemm_128x64x8_splitk_pi_split28", "128x64");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 32 >(output, options, config, "sgemm_128x64x8_splitk_pi_split32", "128x64");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 64, 128>, 64 >(output, options, config, "sgemm_128x64x8_splitk_pi_split64", "128x64");
|
||||
/*128x32x8*/
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 8 >(output, options, config, "sgemm_128x32x8_splitk_pi_split8", "128x32");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 16 >(output, options, config, "sgemm_128x32x8_splitk_pi_split16", "128x32");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 20 >(output, options, config, "sgemm_128x32x8_splitk_pi_split20", "128x32");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 24 >(output, options, config, "sgemm_128x32x8_splitk_pi_split24", "128x32");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 28 >(output, options, config, "sgemm_128x32x8_splitk_pi_split28", "128x32");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 32 >(output, options, config, "sgemm_128x32x8_splitk_pi_split32", "128x32");
|
||||
results |= profile_sgemm_splitkpi_kernel<cutlass::Shape<8, 32, 128>, 64 >(output, options, config, "sgemm_128x32x8_splitk_pi_split64", "128x32");
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
struct SgemmSplitKPIRegistrar {
|
||||
SgemmSplitKPIRegistrar() { RegisterGemmProfileFunc(profile_sgemm_splitkpi); }
|
||||
};
|
||||
|
||||
volatile SgemmSplitKPIRegistrar _SgemmSplitKPIRegistrar;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace perf
|
||||
@ -76,6 +76,15 @@ struct WmmaBinaryGemmDispatch {
|
||||
params.initialize(m, n, k * 32, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
/// batched strided bmma
|
||||
WmmaBinaryGemmDispatch(int m, int n, int k, int alpha,
|
||||
cutlass::Vector<cutlass::bin1_t, 32> const* d_a, int lda, long long int batch_stride_a,
|
||||
cutlass::Vector<cutlass::bin1_t, 32> const* d_b, int ldb, long long int batch_stride_b, int beta,
|
||||
int const* d_c, int ldc, long long int batch_stride_c, int* d_d, int ldd, long long int batch_stride_d,
|
||||
int batch_count) {
|
||||
assert(0);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
WmmaBinaryGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
|
||||
@ -1,27 +1,27 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
@ -92,6 +92,31 @@ struct WmmaGemmDispatch {
|
||||
params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
WmmaGemmDispatch(int m,
|
||||
int n,
|
||||
int k,
|
||||
Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
int lda,
|
||||
long long int batch_stride_A,
|
||||
ScalarB const* d_b,
|
||||
int ldb,
|
||||
long long int batch_stride_B,
|
||||
Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
int ldc,
|
||||
long long int batch_stride_C,
|
||||
ScalarD* d_d,
|
||||
int ldd,
|
||||
long long int batch_stride_D,
|
||||
int batch_count) {
|
||||
params.initialize(m, n, k, alpha, d_a, lda, batch_stride_A,
|
||||
d_b, ldb, batch_stride_B,
|
||||
beta, d_c, ldc, batch_stride_C,
|
||||
d_d, ldd, batch_stride_D,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
WmmaGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
@ -105,6 +130,7 @@ namespace perf {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename DummyT>
|
||||
int profile_wmma_gemm_f32(TestbenchOutput<GemmProblem> &output, TestbenchOptions const &options, Config const &config) {
|
||||
typedef perf::GemmProfiler<cutlass::half_t, cutlass::half_t, float, float, float> GemmProfiler;
|
||||
|
||||
@ -112,8 +138,8 @@ int profile_wmma_gemm_f32(TestbenchOutput<GemmProblem> &output, TestbenchOptions
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor>
|
||||
WmmaGemmTraits;
|
||||
cutlass::MatrixLayout::kRowMajor>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
@ -122,8 +148,8 @@ int profile_wmma_gemm_f32(TestbenchOutput<GemmProblem> &output, TestbenchOptions
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor>
|
||||
WmmaGemmTraits;
|
||||
cutlass::MatrixLayout::kColumnMajor>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
@ -132,7 +158,7 @@ int profile_wmma_gemm_f32(TestbenchOutput<GemmProblem> &output, TestbenchOptions
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor>
|
||||
cutlass::MatrixLayout::kColumnMajor>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
@ -142,7 +168,7 @@ int profile_wmma_gemm_f32(TestbenchOutput<GemmProblem> &output, TestbenchOptions
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor>
|
||||
cutlass::MatrixLayout::kRowMajor>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
@ -155,10 +181,11 @@ int profile_wmma_gemm_f32(TestbenchOutput<GemmProblem> &output, TestbenchOptions
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename DummyT>
|
||||
int profile_wmma_gemm_f16(
|
||||
TestbenchOutput<GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config) {
|
||||
TestbenchOutput<GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config) {
|
||||
|
||||
typedef perf::GemmProfiler<
|
||||
cutlass::half_t,
|
||||
@ -173,7 +200,7 @@ int profile_wmma_gemm_f16(
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 128, 128>,
|
||||
cutlass::Shape<32, 256, 128>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
@ -192,7 +219,7 @@ int profile_wmma_gemm_f16(
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 128, 128>,
|
||||
cutlass::Shape<32, 256, 128>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
@ -211,7 +238,7 @@ int profile_wmma_gemm_f16(
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 128, 128>,
|
||||
cutlass::Shape<32, 256, 128>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
@ -230,7 +257,7 @@ int profile_wmma_gemm_f16(
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 128, 128>,
|
||||
cutlass::Shape<32, 256, 128>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
@ -248,12 +275,283 @@ int profile_wmma_gemm_f16(
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
template <typename DummyT>
|
||||
int profile_wmma_4_gemm_f16(
|
||||
TestbenchOutput<GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config) {
|
||||
|
||||
typedef perf::GemmProfiler<
|
||||
cutlass::half_t,
|
||||
cutlass::half_t,
|
||||
cutlass::half_t,
|
||||
cutlass::half_t,
|
||||
cutlass::half_t> GemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
// a set of test requires leading dim to be multiple of 4 instead of 8
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_gemm_f16_nt", options, config);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_gemm_f16_nn", options, config);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_gemm_f16_tn", options, config);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_gemm_f16_tt", options, config);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
template <typename DummyT>
|
||||
int profile_wmma_4_fp16_sgemm_fp16(
|
||||
TestbenchOutput<GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config) {
|
||||
|
||||
typedef perf::GemmProfiler<
|
||||
cutlass::half_t,
|
||||
cutlass::half_t,
|
||||
cutlass::half_t,
|
||||
float,
|
||||
float> GemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
// a set of test requires leading dim to be multiple of 4 instead of 8
|
||||
|
||||
{
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_fp16_sgemm_fp16_nt", options, config);
|
||||
}
|
||||
|
||||
{
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_fp16_sgemm_fp16_nn", options, config);
|
||||
}
|
||||
|
||||
{
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_fp16_sgemm_fp16_tn", options, config);
|
||||
}
|
||||
|
||||
{
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
typedef WmmaGemmDispatch<WmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "wmma_4_fp16_sgemm_fp16_tt", options, config);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct WmmaGemmRegistrar {
|
||||
WmmaGemmRegistrar() {
|
||||
RegisterGemmProfileFunc(profile_wmma_gemm_f32);
|
||||
RegisterGemmProfileFunc(profile_wmma_gemm_f16);
|
||||
RegisterGemmProfileFunc(profile_wmma_gemm_f32<void>);
|
||||
RegisterGemmProfileFunc(profile_wmma_gemm_f16<void>);
|
||||
|
||||
//#ifdef EXHAUSTIVE_PROF
|
||||
RegisterGemmProfileFunc(profile_wmma_4_gemm_f16<void>);
|
||||
//fp32 accum with fp16 input and output
|
||||
RegisterGemmProfileFunc(profile_wmma_4_fp16_sgemm_fp16<void>);
|
||||
//#endif // defined EXHAUSTIVE_PROF
|
||||
}
|
||||
};
|
||||
|
||||
@ -266,3 +564,4 @@ volatile WmmaGemmRegistrar _WmmaGemmRegistrar;
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
|
||||
|
||||
@ -74,6 +74,15 @@ struct WmmaIntegerGemmDispatch {
|
||||
params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
///
|
||||
WmmaIntegerGemmDispatch(int m, int n, int k, int alpha,
|
||||
ScalarA const* d_a, int lda, long long int batch_stride_a,
|
||||
ScalarB const* d_b, int ldb, long long int batch_stride_b, int beta,
|
||||
int const* d_c, int ldc, long long int batch_stride_c, int* d_d, int ldd, long long int batch_stride_d,
|
||||
int batch_count) {
|
||||
assert(0);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
WmmaIntegerGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
@ -125,6 +134,15 @@ struct WmmaIntegerGemmDispatch<Traits,
|
||||
params.initialize(m, n, k * 8, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
///
|
||||
WmmaIntegerGemmDispatch(int m, int n, int k, int alpha,
|
||||
ScalarA const* d_a, int lda, long long int batch_stride_a,
|
||||
ScalarB const* d_b, int ldb, long long int batch_stride_b, int beta,
|
||||
int const* d_c, int ldc, long long int batch_stride_c, int* d_d, int ldd, long long int batch_stride_d,
|
||||
int batch_count) {
|
||||
assert(0);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
WmmaIntegerGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
@ -176,6 +194,15 @@ struct WmmaIntegerGemmDispatch<Traits,
|
||||
params.initialize(m, n, k * 8, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
///
|
||||
WmmaIntegerGemmDispatch(int m, int n, int k, int alpha,
|
||||
ScalarA const* d_a, int lda, long long int batch_stride_a,
|
||||
ScalarB const* d_b, int ldb, long long int batch_stride_b, int beta,
|
||||
int const* d_c, int ldc, long long int batch_stride_c, int* d_d, int ldd, long long int batch_stride_d,
|
||||
int batch_count) {
|
||||
assert(0);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
WmmaIntegerGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "tools/util/command_line.h"
|
||||
#include "tools/test/perf/provider.h"
|
||||
@ -85,6 +85,7 @@ struct GemmProblem {
|
||||
int m;
|
||||
int n;
|
||||
int k;
|
||||
int batch_count;
|
||||
cutlass::MatrixLayout::Kind layout_A;
|
||||
cutlass::MatrixLayout::Kind layout_B;
|
||||
|
||||
@ -96,7 +97,7 @@ struct GemmProblem {
|
||||
//
|
||||
|
||||
/// Static method to print GemmProblem headers
|
||||
static std::string header() { return "M,N,K,Layout_A,Layout_B,Beta"; }
|
||||
static std::string header() { return "M,N,K,Layout_A,Layout_B,Beta,batch_count"; }
|
||||
|
||||
//
|
||||
// Methods
|
||||
@ -108,21 +109,24 @@ struct GemmProblem {
|
||||
cutlass::MatrixLayout::Kind _layout_A = cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::Kind _layout_B = cutlass::MatrixLayout::kRowMajor,
|
||||
double _alpha = 1,
|
||||
double _beta = 0)
|
||||
: m(_m), n(_n), k(_k), layout_A(_layout_A), layout_B(_layout_B), alpha(_alpha), beta(_beta) {}
|
||||
double _beta = 0,
|
||||
int _batch_count = 1)
|
||||
: m(_m), n(_n), k(_k), layout_A(_layout_A), layout_B(_layout_B), alpha(_alpha), beta(_beta), batch_count(_batch_count) {
|
||||
assert(batch_count >= 1);
|
||||
}
|
||||
|
||||
/// leading dimension of A
|
||||
int lda() const {
|
||||
if (layout_A == cutlass::MatrixLayout::kColumnMajor) {
|
||||
return m;
|
||||
}
|
||||
return k;
|
||||
return k * batch_count;
|
||||
}
|
||||
|
||||
/// leading dimension of B
|
||||
int ldb() const {
|
||||
if (layout_B == cutlass::MatrixLayout::kColumnMajor) {
|
||||
return k;
|
||||
return k * batch_count;
|
||||
}
|
||||
return n;
|
||||
}
|
||||
@ -130,10 +134,35 @@ struct GemmProblem {
|
||||
/// leading dimension of C
|
||||
int ldc() const { return m; }
|
||||
|
||||
/// batch_stride_a. only makes sense when batch_count > 1
|
||||
long long int batch_stride_a() const {
|
||||
assert(batch_count > 1);
|
||||
if (layout_A == cutlass::MatrixLayout::kColumnMajor) {
|
||||
return static_cast<long long int>(k) * static_cast<long long int>(lda());
|
||||
}
|
||||
return static_cast<long long int>(k);
|
||||
}
|
||||
|
||||
/// batch_stride_b. only makes sense when batch_count > 1
|
||||
long long int batch_stride_b() const {
|
||||
assert(batch_count > 1);
|
||||
if (layout_B == cutlass::MatrixLayout::kColumnMajor) {
|
||||
return static_cast<long long int>(k);
|
||||
}
|
||||
return static_cast<long long int>(k) * static_cast<long long int>(ldb());
|
||||
}
|
||||
|
||||
/// batch_stride_c. only makes sense when batch_count > 1
|
||||
long long int batch_stride_c() const {
|
||||
assert(batch_count > 1);
|
||||
return static_cast<long long int>(n) * static_cast<long long int>(ldc());
|
||||
}
|
||||
|
||||
|
||||
/// Pretty prints output
|
||||
std::ostream &pretty_print(std::ostream &out) const {
|
||||
out << m << "-by-" << n << "-by-" << k << ", A: " << layout_A << "-major, B: " << layout_B
|
||||
<< "-major, beta: " << beta;
|
||||
<< "-major, beta: " << beta << ", batch: " << batch_count;
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -142,7 +171,7 @@ struct GemmProblem {
|
||||
/// Prints a problem to an output stream
|
||||
inline std::ostream &operator<<(std::ostream &out, GemmProblem const &problem) {
|
||||
out << problem.m << "," << problem.n << "," << problem.k << "," << problem.layout_A << ","
|
||||
<< problem.layout_B << "," << problem.beta;
|
||||
<< problem.layout_B << "," << problem.beta << "," << problem.batch_count;
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -125,13 +125,16 @@ struct GemmProblemRange {
|
||||
/// Range of sizes in GEMM K dimension
|
||||
Range K;
|
||||
|
||||
/// Range of sizes in batch dimeion
|
||||
Range batch_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor to define a space of probelm sizes
|
||||
GemmProblemRange(Range _M = Range(256), Range _N = Range(256), Range _K = Range(256))
|
||||
: M(_M), N(_N), K(_K) {}
|
||||
GemmProblemRange(Range _M = Range(256), Range _N = Range(256), Range _K = Range(256), Range _batch_count = Range(1))
|
||||
: M(_M), N(_N), K(_K), batch_count(_batch_count) {}
|
||||
|
||||
/// Parses a command line argument as a Range object
|
||||
static void get_range(Range &range,
|
||||
@ -155,6 +158,7 @@ struct GemmProblemRange {
|
||||
get_range(M, args, "m", Range(10240));
|
||||
get_range(N, args, "n", Range(4096));
|
||||
get_range(K, args, "k", Range(4096));
|
||||
get_range(batch_count, args, "batch", Range(1));
|
||||
}
|
||||
};
|
||||
|
||||
@ -368,7 +372,7 @@ struct TestbenchOptions {
|
||||
|
||||
/// Number of iterations
|
||||
int iterations;
|
||||
|
||||
|
||||
/// Defines how to run the benchmark
|
||||
ExecutionMode::Kind execution_mode;
|
||||
|
||||
@ -599,6 +603,9 @@ struct TestbenchOptions {
|
||||
<< " --k=<depth>[:max depth[:step]] "
|
||||
<< " Size of inner dimension of A and B. May specify a range with optional step size.\n"
|
||||
|
||||
<< " --batch=<batch> "
|
||||
<< " Number of batches for a bached gemm. "
|
||||
|
||||
<< " --kernels=<{s|d|h|i|wmma_|wmma_binary_|wmma_integer_}gemm_{nn,nt,tn,tt}>\n"
|
||||
<< " "
|
||||
<< " Select GEMM datatype and layout to use for tests\n"
|
||||
|
||||
@ -39,10 +39,18 @@ set(CUTLASS_UNIT_TEST_HEADERS
|
||||
core/layout_verification.h
|
||||
gemm/run_gemm.h
|
||||
gemm/gemm_testbed.h
|
||||
reduction/batched_reduction_testbed.h
|
||||
reduction/test_batched_reduction.h
|
||||
)
|
||||
|
||||
set(CUTLASS_UNIT_TEST_SOURCES_BACKUP
|
||||
cutlass_unit_test.cpp
|
||||
gemm/batched_strided_sgemm_128x128x8.cu
|
||||
)
|
||||
|
||||
set(CUTLASS_UNIT_TEST_SOURCES
|
||||
cutlass_unit_test.cpp
|
||||
tile_iterator_test.cu
|
||||
core/tensor_ref.cu
|
||||
core/tensor_view.cu
|
||||
util/unique_ptr.cu
|
||||
@ -80,6 +88,9 @@ set(CUTLASS_UNIT_TEST_SOURCES
|
||||
gemm/fp16_sgemm_fp32_128x128x16.cu
|
||||
gemm/fp16_sgemm_fp16_128x128x16.cu
|
||||
gemm/wmma_gemm.cu
|
||||
gemm/fp16_wmma_gemm_fp16.cu
|
||||
gemm/wmma_gemm_non_multiple16.cu
|
||||
gemm/fp16_wmma_gemm_fp16_non_multiple16.cu
|
||||
gemm/wmma_binary_gemm.cu
|
||||
gemm/wmma_integer_gemm.cu
|
||||
gemm/sgemm_threadblock_swizzle_nn.cu
|
||||
@ -89,7 +100,18 @@ set(CUTLASS_UNIT_TEST_SOURCES
|
||||
gemm/batched_strided_sgemm_128x128x8.cu
|
||||
gemm/batched_strided_dgemm_128x128x8.cu
|
||||
gemm/batched_strided_hgemm_128x128x8.cu
|
||||
gemm/batched_strided_wmma_gemm.cu
|
||||
gemm/batched_strided_fp16_wmma_gemm_fp16.cu
|
||||
gemm/epilogue_functor.cu
|
||||
reduction/batched_reduction.cu
|
||||
reduction/mixed_batched_reduction.cu
|
||||
gemm/splitK_sgemm.cu
|
||||
gemm/splitK_igemm.cu
|
||||
gemm/splitK_fp16_sgemm_fp16.cu
|
||||
gemm/splitK_dgemm.cu
|
||||
gemm/splitK_hgemm.cu
|
||||
gemm/splitK_wmma_gemm.cu
|
||||
gemm/partitionedK_sgemm_128x128x8.cu
|
||||
)
|
||||
|
||||
if (CUTLASS_NVRTC_ENABLE)
|
||||
|
||||
@ -124,120 +124,120 @@ TEST(PredicateVector, Count) {
|
||||
{
|
||||
typedef cutlass::PredicateVector<4, 8> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<4, 8> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<4, 8> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<4, 4> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<4, 4> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<4, 4> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<4, 2> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<4, 2> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<4, 2> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<4, 1> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<4, 1> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<4, 1> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<8, 8> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<8, 8> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<8, 8> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<8, 4> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<8, 4> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<8, 4> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<8, 2> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<8, 2> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<8, 2> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<8, 1> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 2)
|
||||
<< "PredicateVector<8, 1> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<8, 1> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<16, 8> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<16, 8> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<16, 8> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<16, 4> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<16, 4> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<16, 4> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<16, 2> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 2)
|
||||
<< "PredicateVector<16, 2> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<16, 2> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<16, 1> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 4)
|
||||
<< "PredicateVector<16, 1> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<16, 1> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<32, 8> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 1)
|
||||
<< "PredicateVector<32, 8> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<32, 8> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<32, 4> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 2)
|
||||
<< "PredicateVector<32, 4> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<32, 4> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<32, 2> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 4)
|
||||
<< "PredicateVector<32, 2> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<32, 2> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<32, 1> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 8)
|
||||
<< "PredicateVector<32, 1> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<32, 1> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<64, 8> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 2)
|
||||
<< "PredicateVector<64, 8> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<64, 8> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<64, 4> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 4)
|
||||
<< "PredicateVector<64, 4> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<64, 4> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<64, 2> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 8)
|
||||
<< "PredicateVector<64, 2> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<64, 2> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::PredicateVector<64, 1> PredicateVector;
|
||||
EXPECT_EQ(int(PredicateVector::kWordCount), 16)
|
||||
<< "PredicateVector<64, 1> word count: " << PredicateVector::kWordCount;
|
||||
<< "PredicateVector<64, 1> word count: " << int(PredicateVector::kWordCount);
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,15 +64,30 @@ void set_gtest_flag() {
|
||||
/// If true, the tests are enabled strictly for one compute capability
|
||||
bool experimental;
|
||||
} test_filters[] = {
|
||||
{ "Sgemm*", 50, false },
|
||||
{ "Dgemm*", 60, false },
|
||||
{ "Fp16_sgemm*", 60, false },
|
||||
{ "Hgemm*", 60, false },
|
||||
{ "Igemm*", 61, false },
|
||||
{ "WmmaGemm*", 70, false },
|
||||
{ "WmmaInt8*", 72, false },
|
||||
{ "WmmaInt4*", 75, true },
|
||||
{ "WmmaBinary*", 75, true },
|
||||
{ "Sgemm*", 50, false },
|
||||
{ "*sgemm*", 50, false },
|
||||
{ "Dgemm*", 60, false },
|
||||
{ "*dgemm*", 60, false },
|
||||
{ "Fp16_sgemm*", 60, false },
|
||||
{ "*fp16_sgemm*", 60, false },
|
||||
{ "Batched_reduction*", 60, false },
|
||||
{ "*batched_reduction*", 60, false },
|
||||
{ "Float_batched_reduction*", 60, false },
|
||||
{ "*float_batched_reduction*", 60, false },
|
||||
{ "SplitK*", 60, false },
|
||||
{ "*splitK*", 60, false },
|
||||
{ "Hgemm*", 60, false },
|
||||
{ "*hgemm*", 60, false },
|
||||
{ "Igemm*", 61, false },
|
||||
{ "*igemm*", 61, false },
|
||||
{ "WmmaGemm*", 70, false },
|
||||
{ "*wmma*", 70, false },
|
||||
{ "WmmaInt8*", 72, false },
|
||||
{ "*wmmaInt8*", 72, false },
|
||||
{ "WmmaInt4*", 75, true },
|
||||
{ "*wmmaInt4*", 75, true },
|
||||
{ "WmmaBinary*", 75, true },
|
||||
{ "*wmmaBinary*", 75, true },
|
||||
{ 0, 0, false }
|
||||
};
|
||||
|
||||
|
||||
385
tools/test/unit/gemm/batched_strided_fp16_wmma_gemm_fp16.cu
Normal file
385
tools/test/unit/gemm/batched_strided_fp16_wmma_gemm_fp16.cu
Normal file
@ -0,0 +1,385 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#if defined(CUTLASS_USE_WMMA_API)
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_32x32x16_nn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_32x32x16_nt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_32x32x16_tn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_32x32x16_tt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//mulitple of 4
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_36x36x16_nn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_36x36x16_nt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_36x36x16_tn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_36x36x16_tt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//mulitple of 2
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_34x34x16_nn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(34, 34, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_34x34x16_nt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(34, 34, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_34x34x16_tn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(34, 34, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f32, fp16_wmma_gemm_fp16_34x34x16_tt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(34, 34, 64, 3);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -34,6 +34,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_256x384x64x3_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(256/*m*/, 384/*n*/, 64/*k*/, 3 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -43,6 +44,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_128x384x192x2_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(128/*m*/, 384/*n*/, 192/*k*/, 2 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -52,6 +54,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_127x384x192x2_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(127/*m*/, 384/*n*/, 192/*k*/, 2 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -61,6 +64,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_127x388x190x2_nn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(127/*m*/, 388/*n*/, 190/*k*/, 2 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -70,6 +74,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_256x384x64x3_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(256/*m*/, 384/*n*/, 64/*k*/, 3 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -79,6 +84,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_128x384x192x2_nt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(128/*m*/, 384/*n*/, 192/*k*/, 2 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -90,6 +96,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_256x384x64x3_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(256/*m*/, 384/*n*/, 64/*k*/, 3 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -99,6 +106,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_128x384x192x2_tn) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(128/*m*/, 384/*n*/, 192/*k*/, 2 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -110,6 +118,7 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_256x384x64x3_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(256/*m*/, 384/*n*/, 64/*k*/, 3 /*batch_size*/);
|
||||
}
|
||||
|
||||
@ -119,8 +128,8 @@ TEST(Sgemm_strided_batched_128x128x8, sgemm_128x384x192x2_tt) {
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
//think about using run_gemm directly
|
||||
run_batched_strided_gemm<SgemmTraits>(128/*m*/, 384/*n*/, 192/*k*/, 2 /*batch_size*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
240
tools/test/unit/gemm/batched_strided_wmma_gemm.cu
Normal file
240
tools/test/unit/gemm/batched_strided_wmma_gemm.cu
Normal file
@ -0,0 +1,240 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#if defined(CUTLASS_USE_WMMA_API)
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_32x32x16_nn) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_32x32x16_nt) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_32x32x16_tn) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_32x32x16_tt) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(32, 32, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//non multiple of 16
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_36x36x16_nn) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_36x36x16_nt) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_36x36x16_tn) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_strided_batched_16x16x32_f16, wmma_gemm_36x36x16_tt) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_batched_strided_gemm<WmmaGemmTraits>(36, 36, 64, 3);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#endif
|
||||
@ -28,7 +28,7 @@
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Fp16_sgemm_alphaFp16_fp16_128x128x16, fp16_sgemm_fp16_128x128x16_nn) {
|
||||
@ -319,3 +319,5 @@ TEST(Fp16_sgemm_alphaFp32_fp16_128x128x16, fp16_sgemm_fp16_128x112x17_tt) {
|
||||
run_gemm<SgemmTraits>(128, 112, 17);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
|
||||
|
||||
TEST(Fp16_sgemm_alphaFp32_fp32_128x128x16, fp16_sgemm_fp32_128x128x16_nn) {
|
||||
@ -172,3 +172,6 @@ TEST(Fp16_sgemm_alphaFp32_fp32_128x128x16, fp16_sgemm_fp32_128x112x17_tt) {
|
||||
SgemmTraits;
|
||||
run_gemm<SgemmTraits>(128, 112, 17);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
381
tools/test/unit/gemm/fp16_wmma_gemm_fp16.cu
Normal file
381
tools/test/unit/gemm/fp16_wmma_gemm_fp16.cu
Normal file
@ -0,0 +1,381 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#if defined(CUTLASS_USE_WMMA_API)
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_16x16x16_nn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(16, 16, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_8x8x16_nn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(8, 8, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_256x256x64_nn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_16x16x16_nt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(16, 16, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_8x8x16_nt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(8, 8, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_256x256x64_nt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_16x16x16_tn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(16, 16, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_8x8x16_tn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(8, 8, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_256x256x64_tn) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_16x16x16_tt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(16, 16, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_8x8x16_tt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(8, 8, 16);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_256x256x64_tt) {
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
8, /*kScalarsPerLdgA_*/
|
||||
8, /*kScalarsPerLdgB_*/
|
||||
8, /*KScalarsPerLdsA_*/
|
||||
8, /*KScalarsPerLdsB_*/
|
||||
16 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
16 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
16 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 64);
|
||||
}
|
||||
|
||||
#endif //#if defined(CUTLASS_USE_WMMA_API)
|
||||
273
tools/test/unit/gemm/fp16_wmma_gemm_fp16_non_multiple16.cu
Normal file
273
tools/test/unit/gemm/fp16_wmma_gemm_fp16_non_multiple16.cu
Normal file
@ -0,0 +1,273 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#if defined(CUTLASS_USE_WMMA_API)
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/* mulitple of 4*/
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_36x36x64_nn) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_36x36x64_nt) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_36x36x64_tn) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_36x36x64_tt) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
8 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
8 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
8 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/* mulitple of 2*/
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_34x34x64_nn) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(34, 34, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/* mulitple of 2*/
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_34x34x64_nt) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(34, 34, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/* mulitple of 2*/
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_34x34x64_tn) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(34, 34, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/* mulitple of 2*/
|
||||
TEST(WmmaGemm_16x16x32_fp32, fp16_wmma_gemm_fp16_34x34x64_tt) {
|
||||
|
||||
typedef float accumu_type;
|
||||
typedef half c_type;
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
c_type,
|
||||
cutlass::gemm::LinearScaling<accumu_type>,
|
||||
accumu_type,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
2, /*kScalarsPerLdgA_*/
|
||||
2, /*kScalarsPerLdgB_*/
|
||||
2, /*KScalarsPerLdsA_*/
|
||||
2, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(c_type), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(accumu_type), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(accumu_type) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(34, 34, 64);
|
||||
}
|
||||
#endif
|
||||
382
tools/test/unit/gemm/gemm_load_global_store_shared.cu
Normal file
382
tools/test/unit/gemm/gemm_load_global_store_shared.cu
Normal file
@ -0,0 +1,382 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_tests.h"
|
||||
#include "tools/util/host_tensor.h"
|
||||
#include "tools/test/unit/core/layout_verification.h"
|
||||
#include "tools/util/tensor_view_io.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
#include "cutlass/gemm/dgemm_traits.h"
|
||||
#include "cutlass/gemm/hgemm_traits.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace test {
|
||||
|
||||
// M/N/K struct.
|
||||
struct GemmDesc {
|
||||
int m, n, k;
|
||||
CUTLASS_HOST_DEVICE GemmDesc(int m_, int n_, int k_) : m(m_), n(n_), k(k_) {}
|
||||
};
|
||||
|
||||
/// Simple test to load from global memory and store to shared memory
|
||||
|
||||
// Loading from global memory and storing to shared memory for A
|
||||
template <typename Traits>
|
||||
__global__ void Gemm_load_global_store_shared_a(
|
||||
typename Traits::GlobalLoadStreamA::Scalar *output,
|
||||
typename Traits::GlobalLoadStreamA::Scalar const *input,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int ldm) {
|
||||
|
||||
//Create shared memory.
|
||||
__shared__ typename Traits::SharedStorage shared_storage;
|
||||
|
||||
// Create those iterators.
|
||||
typedef typename Traits::GlobalLoadStreamA GlobalLoadStreamA;
|
||||
|
||||
typename GlobalLoadStreamA::Params global_load_params;
|
||||
GemmDesc desc(M, N, K);
|
||||
global_load_params.initialize(desc, input, ldm);
|
||||
|
||||
GlobalLoadStreamA stream_a(global_load_params, shared_storage.main_loop.stream_a.global, M, N, K, cutlass::make_Coord(0, 0, 0));
|
||||
stream_a.copy();
|
||||
stream_a.commit();
|
||||
|
||||
// store barrier
|
||||
__syncthreads();
|
||||
|
||||
// one thread writes everything out
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < M*K; ++i) {
|
||||
output[i] = shared_storage.main_loop.stream_a.shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Loading from global memory and storing to shared memory for B
|
||||
template <typename Traits>
|
||||
__global__ void Gemm_load_global_store_shared_b(
|
||||
typename Traits::GlobalLoadStreamB::Scalar *output,
|
||||
typename Traits::GlobalLoadStreamB::Scalar const *input,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int ldm) {
|
||||
|
||||
//Create shared memory.
|
||||
__shared__ typename Traits::SharedStorage shared_storage;
|
||||
|
||||
// Create those iterators.
|
||||
typedef typename Traits::GlobalLoadStreamB GlobalLoadStreamB;
|
||||
typename GlobalLoadStreamB::Params global_load_params;
|
||||
GemmDesc desc(M, N, K);
|
||||
global_load_params.initialize(desc, input, ldm);
|
||||
|
||||
GlobalLoadStreamB stream_b(global_load_params, shared_storage.main_loop.stream_b.global, M, N, K, cutlass::make_Coord(0, 0, 0));
|
||||
stream_b.copy();
|
||||
stream_b.commit();
|
||||
|
||||
// store barrier
|
||||
__syncthreads();
|
||||
|
||||
// one thread writes everything out
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < M*K; ++i) {
|
||||
output[i] = shared_storage.main_loop.stream_b.shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename CtaTile, // concept: Shape
|
||||
typename DestType, // raw data type
|
||||
typename SourceType // raw data type
|
||||
>
|
||||
class VerifyDataMovement {
|
||||
public:
|
||||
|
||||
/// Tensor to store the destination data
|
||||
cutlass::HostTensor<DestType> destination;
|
||||
|
||||
/// Tensor to store the source data
|
||||
cutlass::HostTensor<SourceType> source;
|
||||
|
||||
/// Verification utility
|
||||
typedef test::VerifyLayout<
|
||||
DestType,
|
||||
test::CoordinatePack<DestType>,
|
||||
SourceType,
|
||||
test::CoordinatePack<SourceType> > VerifyLayout;
|
||||
|
||||
/// Verification object
|
||||
VerifyLayout verify_layout;
|
||||
|
||||
public:
|
||||
|
||||
VerifyDataMovement() { }
|
||||
|
||||
VerifyDataMovement(test::Layout const &source_layout) {
|
||||
|
||||
// Actual layout here doesn't matter here, just the number of elements
|
||||
destination.resize_matrix(CtaTile::kH, CtaTile::kW, cutlass::MatrixLayout::kRowMajor);
|
||||
source.resize_matrix(CtaTile::kH, CtaTile::kW, cutlass::MatrixLayout::kRowMajor);
|
||||
|
||||
verify_layout.initialize(source, source_layout);
|
||||
destination.fill(0);
|
||||
|
||||
destination.sync_device();
|
||||
source.sync_device();
|
||||
}
|
||||
|
||||
/// Verifies resulting layout
|
||||
bool verify(test::Layout const & destination_layout) {
|
||||
|
||||
destination.sync_host();
|
||||
|
||||
typename VerifyLayout::VisitorVerbose visitor(std::cout);
|
||||
|
||||
bool passed = verify_layout.verify(
|
||||
destination,
|
||||
destination_layout,
|
||||
visitor);
|
||||
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Gemm_shared_tile, A_float_contiguous) {
|
||||
|
||||
static int const M = 64;
|
||||
static int const N = 64;
|
||||
static int const K = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<K, N, M> >
|
||||
SgemmTraits;
|
||||
|
||||
typedef test::Layout::Span Span;
|
||||
test::Layout::SpanVector dst_layout;
|
||||
test::Layout::SpanVector src_layout;
|
||||
|
||||
// define the source layout
|
||||
src_layout.push_back(Span(0, K));
|
||||
src_layout.push_back(Span(1, M));
|
||||
|
||||
typedef VerifyDataMovement<
|
||||
cutlass::Shape<1, M, K, 1>,
|
||||
float,
|
||||
float
|
||||
> VerifyDataMovement_t;
|
||||
|
||||
VerifyDataMovement_t testbed(src_layout);
|
||||
|
||||
|
||||
test::Gemm_load_global_store_shared_a< SgemmTraits ><<<
|
||||
dim3(1,1,1),
|
||||
dim3(SgemmTraits::kThreads, 1)
|
||||
>>>(
|
||||
testbed.destination.device_data(),
|
||||
testbed.source.device_data(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
M
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
// define the destination layout
|
||||
dst_layout.push_back(Span(0, K));
|
||||
dst_layout.push_back(Span(1, M));
|
||||
|
||||
EXPECT_TRUE(testbed.verify(dst_layout));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Gemm_shared_tile, A_double_contiguous) {
|
||||
|
||||
static int const M = 64;
|
||||
static int const N = 64;
|
||||
static int const K = 8;
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<K, N, M> >
|
||||
DgemmTraits;
|
||||
|
||||
typedef test::Layout::Span Span;
|
||||
test::Layout::SpanVector dst_layout;
|
||||
test::Layout::SpanVector src_layout;
|
||||
|
||||
// define the source layout
|
||||
src_layout.push_back(Span(0, K));
|
||||
src_layout.push_back(Span(1, M));
|
||||
|
||||
typedef VerifyDataMovement<
|
||||
cutlass::Shape<1, M, K, 1>,
|
||||
double,
|
||||
double
|
||||
> VerifyDataMovement_t;
|
||||
|
||||
VerifyDataMovement_t testbed(src_layout);
|
||||
|
||||
test::Gemm_load_global_store_shared_a< DgemmTraits ><<<
|
||||
dim3(1,1,1),
|
||||
dim3(DgemmTraits::kThreads, 1)
|
||||
>>>(
|
||||
testbed.destination.device_data(),
|
||||
testbed.source.device_data(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
M
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
// define the destination layout
|
||||
dst_layout.push_back(Span(0, K));
|
||||
dst_layout.push_back(Span(1, M));
|
||||
|
||||
EXPECT_TRUE(testbed.verify(dst_layout));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Gemm_shared_tile, B_float_contiguous) {
|
||||
|
||||
static int const M = 64;
|
||||
static int const N = 64;
|
||||
static int const K = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<K, N, M> >
|
||||
SgemmTraits;
|
||||
|
||||
typedef test::Layout::Span Span;
|
||||
test::Layout::SpanVector dst_layout;
|
||||
test::Layout::SpanVector src_layout;
|
||||
|
||||
// define the source layout
|
||||
src_layout.push_back(Span(0, K));
|
||||
src_layout.push_back(Span(1, M));
|
||||
|
||||
typedef VerifyDataMovement<
|
||||
cutlass::Shape<1, M, K, 1>,
|
||||
float,
|
||||
float
|
||||
> VerifyDataMovement_t;
|
||||
|
||||
VerifyDataMovement_t testbed(src_layout);
|
||||
|
||||
|
||||
test::Gemm_load_global_store_shared_b< SgemmTraits ><<<
|
||||
dim3(1,1,1),
|
||||
dim3(SgemmTraits::kThreads, 1)
|
||||
>>>(
|
||||
testbed.destination.device_data(),
|
||||
testbed.source.device_data(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
M
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
// define the destination layout
|
||||
dst_layout.push_back(Span(0, K));
|
||||
dst_layout.push_back(Span(1, M));
|
||||
|
||||
EXPECT_TRUE(testbed.verify(dst_layout));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Gemm_shared_tile, B_double_contiguous) {
|
||||
|
||||
static int const M = 64;
|
||||
static int const N = 64;
|
||||
static int const K = 8;
|
||||
|
||||
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<K, N, M> >
|
||||
DgemmTraits;
|
||||
|
||||
typedef test::Layout::Span Span;
|
||||
test::Layout::SpanVector dst_layout;
|
||||
test::Layout::SpanVector src_layout;
|
||||
|
||||
// define the source layout
|
||||
src_layout.push_back(Span(0, K));
|
||||
src_layout.push_back(Span(1, M));
|
||||
|
||||
typedef VerifyDataMovement<
|
||||
cutlass::Shape<1, M, K, 1>,
|
||||
double,
|
||||
double
|
||||
> VerifyDataMovement_t;
|
||||
|
||||
VerifyDataMovement_t testbed(src_layout);
|
||||
|
||||
test::Gemm_load_global_store_shared_b< DgemmTraits ><<<
|
||||
dim3(1,1,1),
|
||||
dim3(DgemmTraits::kThreads, 1)
|
||||
>>>(
|
||||
testbed.destination.device_data(),
|
||||
testbed.source.device_data(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
M
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
// define the destination layout
|
||||
dst_layout.push_back(Span(0, K));
|
||||
dst_layout.push_back(Span(1, M));
|
||||
|
||||
EXPECT_TRUE(testbed.verify(dst_layout));
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@
|
||||
#include "tools/util/type_traits.h"
|
||||
|
||||
#include "tools/util/reference/host/gemm.h"
|
||||
#include "tools/util/reference/device/gemm.h"
|
||||
#include "tools/util/reference/host/tensor_elementwise.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -247,6 +248,9 @@ struct GemmTestbed {
|
||||
/// Reference result computed on the host
|
||||
HostMatrixC ref_host;
|
||||
|
||||
/// Reference result computed on the device
|
||||
HostMatrixC ref_device;
|
||||
|
||||
/// Reference result computed with cublas
|
||||
HostMatrixC ref_cublas;
|
||||
|
||||
@ -262,6 +266,9 @@ struct GemmTestbed {
|
||||
/// batch count
|
||||
int batch_count;
|
||||
|
||||
/// partitionK count
|
||||
int partitionK_count;
|
||||
|
||||
/// distance between A[i] and A[i+1] for strided batched gemm
|
||||
long long int batch_stride_A;
|
||||
|
||||
@ -308,6 +315,7 @@ struct GemmTestbed {
|
||||
beta(beta_),
|
||||
algorithm(algorithm_),
|
||||
batch_count(1),
|
||||
partitionK_count(1),
|
||||
batch_stride_A(static_cast<long long int>(0)),
|
||||
batch_stride_B(static_cast<long long int>(0)),
|
||||
batch_stride_C(static_cast<long long int>(0)) {
|
||||
@ -320,6 +328,7 @@ struct GemmTestbed {
|
||||
resize(B, K_, N_, layout_b);
|
||||
resize(C_initial, M_, N_, layout_c);
|
||||
resize(ref_host, M_, N_, layout_c);
|
||||
resize(ref_device, M_, N_, layout_c);
|
||||
resize(ref_cublas, M_, N_, layout_c);
|
||||
resize(computed, M_, N_, layout_c);
|
||||
}
|
||||
@ -345,6 +354,7 @@ struct GemmTestbed {
|
||||
beta(beta_),
|
||||
algorithm(algorithm_),
|
||||
batch_count(1),
|
||||
partitionK_count(1),
|
||||
batch_stride_A(static_cast<long long int>(0)),
|
||||
batch_stride_B(static_cast<long long int>(0)),
|
||||
batch_stride_C(static_cast<long long int>(0)) {
|
||||
@ -353,6 +363,7 @@ struct GemmTestbed {
|
||||
resize(B, K_ * batch_count, N_, layout_b);
|
||||
resize(C_initial, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_host, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_device, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_cublas, M_, N_ * batch_count, layout_c);
|
||||
resize(computed, M_, N_ * batch_count, layout_c);
|
||||
}
|
||||
@ -377,6 +388,7 @@ struct GemmTestbed {
|
||||
beta(beta_),
|
||||
algorithm(algorithm_),
|
||||
batch_count(1),
|
||||
partitionK_count(1),
|
||||
batch_stride_A(static_cast<long long int>(0)),
|
||||
batch_stride_B(static_cast<long long int>(0)),
|
||||
batch_stride_C(static_cast<long long int>(0)) {
|
||||
@ -389,6 +401,7 @@ struct GemmTestbed {
|
||||
resize(B, K_, N_, layout_b, ldb);
|
||||
resize(C_initial, M_, N_, layout_c, ldc);
|
||||
resize(ref_host, M_, N_, layout_c, ldc);
|
||||
resize(ref_device, M_, N_, layout_c, ldc);
|
||||
resize(ref_cublas, M_, N_, layout_c, ldc);
|
||||
resize(computed, M_, N_, layout_c, ldc);
|
||||
}
|
||||
@ -414,6 +427,7 @@ struct GemmTestbed {
|
||||
beta(beta_),
|
||||
algorithm(algorithm_),
|
||||
batch_count(1),
|
||||
partitionK_count(1),
|
||||
batch_stride_A(static_cast<long long int>(0)),
|
||||
batch_stride_B(static_cast<long long int>(0)),
|
||||
batch_stride_C(static_cast<long long int>(0)) {
|
||||
@ -422,6 +436,7 @@ struct GemmTestbed {
|
||||
resize(B, K_ * batch_count, N_, layout_b);
|
||||
resize(C_initial, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_host, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_device, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_cublas, M_, N_ * batch_count, layout_c);
|
||||
resize(computed, M_, N_ * batch_count, layout_c);
|
||||
}
|
||||
@ -446,7 +461,8 @@ struct GemmTestbed {
|
||||
alpha(alpha_),
|
||||
beta(beta_),
|
||||
algorithm(algorithm_),
|
||||
batch_count(batch_count_) {
|
||||
batch_count(batch_count_),
|
||||
partitionK_count(1) {
|
||||
|
||||
status = cublasCreate(&handle);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
@ -457,6 +473,7 @@ struct GemmTestbed {
|
||||
resize(B, K_ * batch_count, N_, layout_b);
|
||||
resize(C_initial, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_host, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_device, M_, N_ * batch_count, layout_c);
|
||||
resize(ref_cublas, M_, N_ * batch_count, layout_c);
|
||||
resize(computed, M_, N_ * batch_count, layout_c);
|
||||
|
||||
@ -465,6 +482,50 @@ struct GemmTestbed {
|
||||
batch_stride_C = M_ * N_;
|
||||
}
|
||||
|
||||
/// Constructs a workspace for verifying partitionedK GEMM, assumes
|
||||
/// dense packing.
|
||||
/// in partitionedK GEMM, the K is partitioned by partitionK_size
|
||||
/// each partition is of the same size, except for the last partition
|
||||
/// each partition, except for the last one, is of size K / partitionK_count
|
||||
/// if K is not divisible by partitionK_size, the last partitionK = K % partitionK_count + K / partitionK_count
|
||||
GemmTestbed(int M_,
|
||||
int N_,
|
||||
std::pair<int, int> K_pair_, /*(k, partitionK_count)*/
|
||||
cublasOperation_t layout_a,
|
||||
cublasOperation_t layout_b,
|
||||
Scalar alpha_ = Scalar(1),
|
||||
Scalar beta_ = Scalar(0),
|
||||
cublasGemmAlgo_t algorithm_ = CUBLAS_GEMM_DEFAULT,
|
||||
cublasOperation_t layout_c = CUBLAS_OP_N)
|
||||
: problem_size(K_pair_.first, N_, M_, 1),
|
||||
layout_A(layout_a),
|
||||
layout_B(layout_b),
|
||||
alpha(alpha_),
|
||||
beta(beta_),
|
||||
algorithm(algorithm_),
|
||||
batch_count(1),
|
||||
partitionK_count(K_pair_.second) {
|
||||
|
||||
status = cublasCreate(&handle);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
throw cutlass::cuda_exception("Failed to create CUBLAS handle");
|
||||
}
|
||||
resize(A, M_, K_pair_.first, layout_a);
|
||||
resize(B, K_pair_.first, N_, layout_b);
|
||||
resize(C_initial, M_, N_ * partitionK_count, layout_c);
|
||||
resize(ref_host, M_, N_ * partitionK_count, layout_c);
|
||||
resize(ref_device, M_, N_ * partitionK_count, layout_c);
|
||||
resize(ref_cublas, M_, N_ * partitionK_count, layout_c);
|
||||
resize(computed, M_, N_ * partitionK_count, layout_c);
|
||||
|
||||
// we can use a combination of batched stried gemm and regular gemm
|
||||
// to simulation partitionedK, which is what we will do for reference code
|
||||
int partitionK_size = K() / partitionK_count;
|
||||
batch_stride_A = (layout_a == CUBLAS_OP_N) ? M_ * partitionK_size : partitionK_size;
|
||||
batch_stride_B = (layout_b == CUBLAS_OP_N) ? partitionK_size : partitionK_size * N_;
|
||||
batch_stride_C = M_ * N_;
|
||||
}
|
||||
|
||||
/// Destructs the GEMM testbed
|
||||
~GemmTestbed() {
|
||||
if (status != CUBLAS_STATUS_NOT_INITIALIZED) {
|
||||
@ -504,7 +565,14 @@ struct GemmTestbed {
|
||||
|
||||
/// Returns the number of flops implied by the computation (1 multiply-accumulate = 2 flops)
|
||||
uint64_t flops() const {
|
||||
return uint64_t(batch_count) * uint64_t(M()) * uint64_t(N()) * uint64_t(K()) * 2ULL;
|
||||
if (partitionK_count == 1) {
|
||||
return uint64_t(batch_count) * uint64_t(M()) * uint64_t(N()) * uint64_t(K()) * 2ULL;
|
||||
}
|
||||
else {
|
||||
int partitionK_size = K() / partitionK_count;
|
||||
return (uint64_t(partitionK_count - 1) * uint64_t(batch_count) * uint64_t(M()) * uint64_t(N()) * uint64_t(partitionK_size) * 2ULL)
|
||||
+ (uint64_t(batch_count) * uint64_t(M()) * uint64_t(N()) * uint64_t(K() - partitionK_size * (partitionK_count - 1)) * 2ULL);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the speed of the computation in GFLOPs/s
|
||||
@ -555,14 +623,15 @@ struct GemmTestbed {
|
||||
// Initialize the source matrix with a uniform distribution
|
||||
cutlass::Distribution dist;
|
||||
dist.set_uniform(-8, 8);
|
||||
|
||||
|
||||
cutlass::reference::host::TensorInitialize(A.host_view(), seed, dist);
|
||||
cutlass::reference::host::TensorInitialize(B.host_view(), seed + 11, dist);
|
||||
cutlass::reference::host::TensorInitialize(C_initial.host_view(), seed + 13, dist);
|
||||
|
||||
|
||||
A.sync_device();
|
||||
B.sync_device();
|
||||
C_initial.sync_device();
|
||||
|
||||
}
|
||||
|
||||
/// Initializes binary data
|
||||
@ -585,56 +654,121 @@ struct GemmTestbed {
|
||||
/// Computes the matrix product on the host
|
||||
void compute_host() {
|
||||
ref_host.fill(C_initial);
|
||||
|
||||
cutlass::reference::host::Gemm(problem_size, alpha, A.host_ref(), B.host_ref(), beta, ref_host.host_ref(), Accumulator(0));
|
||||
}
|
||||
|
||||
/// Compute the matrix product using the device-side reference
|
||||
void compute_device_reference() {
|
||||
ref_device.fill(C_initial);
|
||||
cutlass::reference::device::Gemm(
|
||||
problem_size,
|
||||
cutlass::TypeTraits<Scalar>::to_device(alpha),
|
||||
A.device_ref(),
|
||||
B.device_ref(),
|
||||
cutlass::TypeTraits<Scalar>::to_device(beta),
|
||||
ref_device.device_ref(),
|
||||
cutlass::TypeTraits<Accumulator>::to_device(0)
|
||||
);
|
||||
}
|
||||
|
||||
/// Excutes an equivalent GEMM using cuBLAS
|
||||
bool execute_cublas() {
|
||||
if (batch_count == 1) {
|
||||
status = cublasGemmEx(handle,
|
||||
layout_a(),
|
||||
layout_b(),
|
||||
M(),
|
||||
N(),
|
||||
K(),
|
||||
&alpha,
|
||||
ptr_A(),
|
||||
cutlass::TypeTraits<AType>::cublas_type,
|
||||
lda(),
|
||||
ptr_B(),
|
||||
cutlass::TypeTraits<BType>::cublas_type,
|
||||
ldb(),
|
||||
&beta,
|
||||
ref_cublas.device_data(),
|
||||
cutlass::TypeTraits<CType>::cublas_type,
|
||||
ldc(),
|
||||
cutlass::TypeTraits<Accumulator>::cublas_type,
|
||||
algorithm);
|
||||
if (partitionK_count == 1) {
|
||||
if (batch_count == 1) {
|
||||
status = cublasGemmEx(handle,
|
||||
layout_a(),
|
||||
layout_b(),
|
||||
M(),
|
||||
N(),
|
||||
K(),
|
||||
&alpha,
|
||||
ptr_A(),
|
||||
cutlass::TypeTraits<AType>::cublas_type,
|
||||
lda(),
|
||||
ptr_B(),
|
||||
cutlass::TypeTraits<BType>::cublas_type,
|
||||
ldb(),
|
||||
&beta,
|
||||
ref_cublas.device_data(),
|
||||
cutlass::TypeTraits<CType>::cublas_type,
|
||||
ldc(),
|
||||
cutlass::TypeTraits<Accumulator>::cublas_type,
|
||||
algorithm);
|
||||
|
||||
return status == CUBLAS_STATUS_SUCCESS;
|
||||
} else {
|
||||
// call strided batched gemm
|
||||
return status == CUBLAS_STATUS_SUCCESS;
|
||||
}
|
||||
else {
|
||||
// call strided batched gemm
|
||||
status = cublasGemmStridedBatchedTemplate(handle,
|
||||
layout_a(),
|
||||
layout_b(),
|
||||
M(),
|
||||
N(),
|
||||
K(),
|
||||
&alpha,
|
||||
ptr_A(),
|
||||
lda(),
|
||||
batch_stride_A,
|
||||
ptr_B(),
|
||||
ldb(),
|
||||
batch_stride_B,
|
||||
&beta,
|
||||
ref_cublas.device_data(),
|
||||
ldc(),
|
||||
batch_stride_C,
|
||||
batch_count);
|
||||
|
||||
return status == CUBLAS_STATUS_SUCCESS;
|
||||
}
|
||||
}
|
||||
else {
|
||||
assert(batch_count == 1);
|
||||
//the last batch is of a different K
|
||||
//first call strided batched gemm
|
||||
|
||||
int partitionK_size = K() / partitionK_count;
|
||||
//int lastK_size = (K() % partitionK_size) + partitionK_size;
|
||||
int lastK_size = K() - partitionK_size * (partitionK_count - 1);
|
||||
status = cublasGemmStridedBatchedTemplate(handle,
|
||||
layout_a(),
|
||||
layout_b(),
|
||||
M(),
|
||||
N(),
|
||||
K(),
|
||||
&alpha,
|
||||
ptr_A(),
|
||||
lda(),
|
||||
batch_stride_A,
|
||||
ptr_B(),
|
||||
ldb(),
|
||||
batch_stride_B,
|
||||
&beta,
|
||||
ref_cublas.device_data(),
|
||||
ldc(),
|
||||
batch_stride_C,
|
||||
batch_count);
|
||||
|
||||
layout_a(),
|
||||
layout_b(),
|
||||
M(),
|
||||
N(),
|
||||
partitionK_size,
|
||||
&alpha,
|
||||
ptr_A(),
|
||||
lda(),
|
||||
batch_stride_A,
|
||||
ptr_B(),
|
||||
ldb(),
|
||||
batch_stride_B,
|
||||
&beta,
|
||||
ref_cublas.device_data(),
|
||||
ldc(),
|
||||
batch_stride_C,
|
||||
partitionK_count - 1);
|
||||
//then call gemm for the last batch
|
||||
status = cublasGemmEx(handle,
|
||||
layout_a(),
|
||||
layout_b(),
|
||||
M(),
|
||||
N(),
|
||||
lastK_size,
|
||||
&alpha,
|
||||
ptr_A() + (partitionK_count - 1) * batch_stride_A,
|
||||
cutlass::TypeTraits<AType>::cublas_type,
|
||||
lda(),
|
||||
ptr_B() + (partitionK_count - 1) * batch_stride_B,
|
||||
cutlass::TypeTraits<BType>::cublas_type,
|
||||
ldb(),
|
||||
&beta,
|
||||
ref_cublas.device_data() + (partitionK_count - 1) * batch_stride_C,
|
||||
cutlass::TypeTraits<CType>::cublas_type,
|
||||
ldc(),
|
||||
cutlass::TypeTraits<Accumulator>::cublas_type,
|
||||
algorithm);
|
||||
return status == CUBLAS_STATUS_SUCCESS;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -787,6 +921,24 @@ struct GemmTestbed {
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Verifies the reference implementation with cuBLAS
|
||||
bool verify_reference_with_cublas(bool save_on_error = true, bool always_print = false) {
|
||||
|
||||
compute_device_reference();
|
||||
ref_device.sync_host();
|
||||
|
||||
compute_cublas();
|
||||
ref_cublas.sync_host();
|
||||
|
||||
bool passed = ref_device.bit_equals(ref_cublas);
|
||||
|
||||
if ((!passed && save_on_error) || always_print) {
|
||||
save_workspace(ref_device, ref_cublas);
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Verifies with host-side and device-side computations
|
||||
bool verify_with_all() {
|
||||
bool passed = true;
|
||||
@ -917,4 +1069,44 @@ template<> inline cublasStatus_t GemmTestbed<cutlass::half_t, cutlass::half_t, c
|
||||
batchCount);
|
||||
}
|
||||
|
||||
template<> inline cublasStatus_t GemmTestbed<cutlass::half_t, cutlass::half_t, cutlass::half_t, float, float>::cublasGemmStridedBatchedTemplate(cublasHandle_t handle,
|
||||
cublasOperation_t transa,
|
||||
cublasOperation_t transb,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
const float *alpha,
|
||||
const half *ptr_A,
|
||||
int lda,
|
||||
long long int stride_A,
|
||||
const half *ptr_B,
|
||||
int ldb,
|
||||
long long int stride_B,
|
||||
const float *beta,
|
||||
half *ptr_C,
|
||||
int ldc,
|
||||
long long int stride_C,
|
||||
int batchCount) {
|
||||
return cublasGemmStridedBatchedEx(handle,
|
||||
transa,
|
||||
transb,
|
||||
M, N, K,
|
||||
alpha,
|
||||
ptr_A,
|
||||
cutlass::TypeTraits<cutlass::half_t>::cublas_type,
|
||||
lda,
|
||||
stride_A,
|
||||
ptr_B,
|
||||
cutlass::TypeTraits<cutlass::half_t>::cublas_type,
|
||||
ldb,
|
||||
stride_B,
|
||||
beta,
|
||||
ptr_C,
|
||||
cutlass::TypeTraits<cutlass::half_t>::cublas_type,
|
||||
ldc,
|
||||
stride_C,
|
||||
batchCount,
|
||||
cutlass::TypeTraits<float>::cublas_type,
|
||||
CUBLAS_GEMM_DEFAULT);
|
||||
}
|
||||
} // namespace test
|
||||
|
||||
@ -29,6 +29,8 @@
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x16, hgemm_128x128x16_nt) {
|
||||
@ -326,4 +328,5 @@ TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_alpha2_beta1_nt) {
|
||||
run_gemm<HgemmTraits>(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1));
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#endif
|
||||
|
||||
|
||||
@ -29,6 +29,8 @@
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x128x8, hgemm_128x128x1_nt) {
|
||||
@ -384,5 +386,5 @@ TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_alpha2_beta1_nt) {
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x32x8, hgemm_128x32x1_nt) {
|
||||
@ -312,3 +313,5 @@ TEST(Hgemm_128x32x8, hgemm_256x64x16_tt) {
|
||||
run_gemm<HgemmTraits>(256, 64, 16);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#endif
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Hgemm_128x64x8, hgemm_128x64x1_nt) {
|
||||
@ -312,3 +313,5 @@ TEST(Hgemm_128x64x8, hgemm_256x128x16_tt) {
|
||||
run_gemm<HgemmTraits>(256, 128, 16);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#endif
|
||||
|
||||
|
||||
378
tools/test/unit/gemm/partitionedK_sgemm_128x128x8.cu
Normal file
378
tools/test/unit/gemm/partitionedK_sgemm_128x128x8.cu
Normal file
@ -0,0 +1,378 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x8_nn) {
|
||||
/*
|
||||
for example
|
||||
partitionedK sgemm, m = 128, n = 256, overall_K = 100, partitionK_count = 8
|
||||
for the first 7 partition k = overall_k / partitionK_count = 12
|
||||
for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16
|
||||
*/
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x175x8_nn) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 175;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x20x3_nn) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 20;
|
||||
int partitionK_count = 3;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x60x8_nn) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 60;
|
||||
int partitionK_count = 8;
|
||||
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x4_nn) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 4;
|
||||
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x8_nt) {
|
||||
/*
|
||||
for example
|
||||
partitionedK sgemm, m = 128, n = 256, overall_K = 100, partitionK_count = 8
|
||||
for the first 7 partition k = overall_k / partitionK_count = 12
|
||||
for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16
|
||||
*/
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x175x8_nt) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 175;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x20x3_nt) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 20;
|
||||
int partitionK_count = 3;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x60x8_nt) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 60;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x4_nt) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 4;
|
||||
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x8_tn) {
|
||||
/*
|
||||
for example
|
||||
partitionedK sgemm, m = 128, n = 256, overall_K = 100, partitionK_count = 8
|
||||
for the first 7 partition k = overall_k / partitionK_count = 12
|
||||
for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16
|
||||
*/
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x175x8_tn) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 175;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x20x3_tn) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 20;
|
||||
int partitionK_count = 3;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x60x8_tn) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 60;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x4_tn) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 4;
|
||||
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x8_tt) {
|
||||
/*
|
||||
for example
|
||||
partitionedK sgemm, m = 128, n = 256, overall_K = 100, partitionK_count = 8
|
||||
for the first 7 partition k = overall_k / partitionK_count = 12
|
||||
for the last partition last_k = overall_k - (partitionK_count - 1) * k = 16
|
||||
*/
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x175x8_tt) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 175;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x20x3_tt) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 20;
|
||||
int partitionK_count = 3;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_10x12x60x8_tt) {
|
||||
|
||||
int m = 10;
|
||||
int n = 12;
|
||||
int overall_k = 60;
|
||||
int partitionK_count = 8;
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Sgemm_partitionedK_128x128x8, sgemm_128x256x100x4_tt) {
|
||||
|
||||
int m = 128;
|
||||
int n = 256;
|
||||
int overall_k = 100;
|
||||
int partitionK_count = 4;
|
||||
|
||||
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
run_partitioned_k_gemm<SgemmTraits>(m, n, overall_k, partitionK_count);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -25,8 +25,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "cutlass/gemm/device_gemm.h"
|
||||
#include "cutlass/gemm/device_gemm_traits.h"
|
||||
|
||||
template <typename GemmTraits_>
|
||||
static void run_gemm(
|
||||
int m,
|
||||
@ -36,9 +40,9 @@ static void run_gemm(
|
||||
int ldb,
|
||||
int ldc,
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1.0f),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0)) {
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0.0f)) {
|
||||
|
||||
typedef typename GemmTraits_::KernelClass Gemm;
|
||||
typename Gemm::Params params;
|
||||
@ -69,8 +73,10 @@ static void run_gemm(
|
||||
|
||||
if (testbed.has_cublas_support()) {
|
||||
EXPECT_TRUE(testbed.verify_host_with_cublas());
|
||||
EXPECT_TRUE(testbed.verify_reference_with_cublas());
|
||||
}
|
||||
|
||||
|
||||
params.initialize(testbed.M(),
|
||||
testbed.N(),
|
||||
testbed.K(),
|
||||
@ -137,6 +143,7 @@ static void run_gemm(
|
||||
|
||||
if (testbed.has_cublas_support()) {
|
||||
EXPECT_TRUE(testbed.verify_host_with_cublas());
|
||||
EXPECT_TRUE(testbed.verify_reference_with_cublas());
|
||||
}
|
||||
|
||||
params.initialize(testbed.M(),
|
||||
@ -175,9 +182,9 @@ static void run_batched_strided_gemm(
|
||||
int k,
|
||||
int batch_count,
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1.0f),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0)) {
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0.0f)) {
|
||||
//typedef cutlass::gemm::Gemm<GemmTraits_> Gemm;
|
||||
typedef typename GemmTraits_::KernelClass Gemm;
|
||||
typename Gemm::Params params;
|
||||
@ -242,3 +249,153 @@ static void run_batched_strided_gemm(
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_, typename ReductionTraits_>
|
||||
static void run_splitK_gemm(int m,
|
||||
int n,
|
||||
int k,
|
||||
typename test::GemmTestbedTraits<typename ReductionTraits_::ScalarAlphaBeta>::host_type alpha =
|
||||
typename test::GemmTestbedTraits<typename ReductionTraits_::ScalarAlphaBeta>::host_type(1.0f),
|
||||
typename test::GemmTestbedTraits<typename ReductionTraits_::ScalarAlphaBeta>::host_type beta =
|
||||
typename test::GemmTestbedTraits<typename ReductionTraits_::ScalarAlphaBeta>::host_type(0.0f),
|
||||
bool use_host_reference = false){
|
||||
|
||||
test::GemmTestbed<
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::GemmConfig::ScalarA>::host_type, // AType
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::GemmConfig::ScalarB>::host_type, // BType
|
||||
typename test::GemmTestbedTraits<
|
||||
typename ReductionTraits_::ScalarC>::host_type, // CType
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::GemmConfig::ScalarD>::host_type, // Workspace Accumulator
|
||||
typename test::GemmTestbedTraits<typename ReductionTraits_::ScalarAlphaBeta>::host_type // Scalar
|
||||
>
|
||||
testbed(m,
|
||||
n,
|
||||
k,
|
||||
test::convert(GemmTraits_::kLayoutA),
|
||||
test::convert(GemmTraits_::kLayoutB),
|
||||
alpha,
|
||||
beta);
|
||||
|
||||
testbed.initialize();
|
||||
|
||||
// create a device gemm
|
||||
typedef cutlass::gemm::SplitkPIGemmTraits<GemmTraits_, ReductionTraits_> deviceGemmTraits;
|
||||
typedef typename deviceGemmTraits::KernelClass deviceGemm;
|
||||
typename deviceGemm::Params deviceGemmParams(testbed.M(), testbed.N(), testbed.K());
|
||||
|
||||
// query if workspace is needed
|
||||
int workspace_size = deviceGemmParams.required_workspace_memory_in_byte();
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::GemmConfig::ScalarD>::device_type
|
||||
*workspace_ptr = 0;
|
||||
if (workspace_size != 0) {
|
||||
cudaError_t workspace_err = cudaMalloc(&workspace_ptr, workspace_size);
|
||||
ASSERT_EQ(workspace_err, cudaSuccess) << "\nCUDA workspace malloc error: " << cudaGetErrorString(workspace_err)
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
deviceGemmParams.initialize(testbed.alpha,
|
||||
testbed.ptr_A(),
|
||||
testbed.lda(),
|
||||
testbed.ptr_B(),
|
||||
testbed.ldb(),
|
||||
testbed.beta,
|
||||
testbed.ptr_C_initial(),
|
||||
testbed.ldc(),
|
||||
testbed.ptr_computed(),
|
||||
testbed.ldc(),
|
||||
workspace_ptr);
|
||||
|
||||
|
||||
deviceGemm::launch(deviceGemmParams);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
if (workspace_size != 0) {
|
||||
cudaError_t workspace_err = cudaFree(workspace_ptr);
|
||||
ASSERT_EQ(workspace_err, cudaSuccess) << "\nCUDA workspace free error: " << cudaGetErrorString(workspace_err)
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
if (use_host_reference == true || testbed.has_cublas_support() == false) {
|
||||
ASSERT_TRUE(testbed.verify_with_host());
|
||||
}
|
||||
else {
|
||||
ASSERT_TRUE(testbed.verify_with_cublas());
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
static void run_partitioned_k_gemm(
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int partitionK_count,
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type alpha =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(1.0f),
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type beta =
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type(0.0f)) {
|
||||
//typedef cutlass::gemm::Gemm<GemmTraits_> Gemm;
|
||||
typedef typename GemmTraits_::KernelClass Gemm;
|
||||
typename Gemm::Params params;
|
||||
test::GemmTestbed<
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::GemmConfig::ScalarA>::host_type, // AType
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::GemmConfig::ScalarB>::host_type, // BType
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::Epilogue::ScalarC>::host_type, // CType
|
||||
typename test::GemmTestbedTraits<
|
||||
typename GemmTraits_::Epilogue::Accumulators::Element>::host_type, // Accumulator
|
||||
typename test::GemmTestbedTraits<typename GemmTraits_::Epilogue::Scalar>::host_type // Scalar
|
||||
>
|
||||
testbed(m,
|
||||
n,
|
||||
std::make_pair(k, partitionK_count),
|
||||
test::convert(GemmTraits_::kLayoutA),
|
||||
test::convert(GemmTraits_::kLayoutB),
|
||||
alpha,
|
||||
beta);
|
||||
|
||||
testbed.initialize();
|
||||
|
||||
// host support is not implemented for strided batched gemm
|
||||
// if (testbed.has_cublas_support()) {
|
||||
// EXPECT_TRUE(testbed.verify_host_with_cublas());
|
||||
//}
|
||||
|
||||
params.initialize(testbed.M(),
|
||||
testbed.N(),
|
||||
testbed.K(),
|
||||
testbed.alpha,
|
||||
testbed.ptr_A(),
|
||||
testbed.lda(),
|
||||
testbed.ptr_B(),
|
||||
testbed.ldb(),
|
||||
testbed.beta,
|
||||
testbed.ptr_C_initial(),
|
||||
testbed.ldc(),
|
||||
testbed.ptr_computed(),
|
||||
testbed.ldc(),
|
||||
partitionK_count);
|
||||
|
||||
Gemm::launch(params);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
if (testbed.has_cublas_support()) {
|
||||
ASSERT_TRUE(testbed.verify_with_cublas());
|
||||
}
|
||||
else {
|
||||
// ASSERT_TRUE(testbed.verify_with_host());
|
||||
ASSERT_TRUE(false) << "host support is not implemented for strided batched gemm" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
247
tools/test/unit/gemm/splitK_dgemm.cu
Normal file
247
tools/test/unit/gemm/splitK_dgemm.cu
Normal file
@ -0,0 +1,247 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dgemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x512_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_dgemm_128x128x8_splits16, dgemm_128x256x500_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
DgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<DgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
579
tools/test/unit/gemm/splitK_fp16_sgemm_fp16.cu
Normal file
579
tools/test/unit/gemm/splitK_fp16_sgemm_fp16.cu
Normal file
@ -0,0 +1,579 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/fp16_sgemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
for fp16_sgemm_fp16 A, B, C and D are half typed. alpha and beta can be half or float typed.
|
||||
Accumulation is float typed.
|
||||
1. in batched gemm kernel, Ab and Bb are half typed, and pointing to A and B.
|
||||
Cb and Db are float typed, since Db is actually pointing to the workspace memory
|
||||
thus is of the same type with accumulation. Cb is generally ignored since beta is zero. alpha is one.
|
||||
2. in the reduction kernel. Dr = alpha * Reduction(Ar) + beta * Cr. Ar is float typed and pointing to the same
|
||||
workspace memory with Db. Cr is half typed and pointing to C. Dr is half typed and pointing to D.
|
||||
ALPHAr is the same with alpha, BETAr is the same with beta.
|
||||
*/
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x512_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFloat_128x128x8_splits16, sgemm_128x256x500_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
float, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x512_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//k = 500
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_fp16_sgemm_fp16_alphabetaFp16_128x128x8_splits16, sgemm_128x256x500_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, /*Ab type*/
|
||||
half, /*Bb type*/
|
||||
float, /*Cb type*/
|
||||
float, /*Db type*/
|
||||
float /*alpha, beta type*/
|
||||
>
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*Ar type*/
|
||||
half, /*Cr type*/
|
||||
half, /*Dr type*/
|
||||
half, /*alpha, beta type*/
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
248
tools/test/unit/gemm/splitK_hgemm.cu
Normal file
248
tools/test/unit/gemm/splitK_hgemm.cu
Normal file
@ -0,0 +1,248 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/hgemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 64;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 64;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 64;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x64_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 64;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 66;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 66;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 66;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_hgemm_128x128x8_splits16, hgemm_128x256x66_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 66;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::HgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
HgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<HgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
367
tools/test/unit/gemm/splitK_igemm.cu
Normal file
367
tools/test/unit/gemm/splitK_igemm.cu
Normal file
@ -0,0 +1,367 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/igemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 2, 1, true /*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 2, 1, true /*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 2, 1, true /*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_128x256x512_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 2, 1, true /*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x128x32_splits16, igemm_1024x64x4096_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 128, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 32, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 32, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 32, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_igemm_128x32x32_splits16, igemm_1024x64x4096_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched igemm traits*/
|
||||
typedef cutlass::gemm::IgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 32, 128>, int, cutlass::gemm::LinearScaling<int> >
|
||||
IgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<IgemmTraits, BatchedReductionTraits>(m, n, k, 1, 0, false /*not use host reference*/);
|
||||
}
|
||||
355
tools/test/unit/gemm/splitK_sgemm.cu
Normal file
355
tools/test/unit/gemm/splitK_sgemm.cu
Normal file
@ -0,0 +1,355 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x512_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_128x256x500_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 500;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_sgemm_128x128x8_splits16, sgemm_1024x64x4096_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 1024;
|
||||
const int n = 64;
|
||||
const int k = 4096;
|
||||
|
||||
/*batched sgemm traits*/
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<SgemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f);
|
||||
}
|
||||
175
tools/test/unit/gemm/splitK_wmma_gemm.cu
Normal file
175
tools/test/unit/gemm/splitK_wmma_gemm.cu
Normal file
@ -0,0 +1,175 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#if defined(CUTLASS_USE_WMMA_API)
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_nn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched wmma gemm traits*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<WmmaGemmTraits, BatchedReductionTraits>(m, n, k, 2.0f, 1.0f, true/*use host reference*/);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_nt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched wmma gemm traits*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<WmmaGemmTraits, BatchedReductionTraits>(m, n, k, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_tn) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched wmma gemm traits*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<WmmaGemmTraits, BatchedReductionTraits>(m, n, k, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(SplitK_wmma_gemm_16x16x32_splits16, wmma_gemm_128x256x512_tt) {
|
||||
const int splits_count = 16;
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int k = 512;
|
||||
|
||||
/*batched wmma gemm traits*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
/*batched reduction traits*/
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
splits_count,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits;
|
||||
|
||||
run_splitK_gemm<WmmaGemmTraits, BatchedReductionTraits>(m, n, k, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -53,6 +53,7 @@ TEST(WmmaGemm_16x16x32_f16, wmma_gemm_16x16x16_nn) {
|
||||
run_gemm<WmmaGemmTraits>(16, 16, 16);
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_f16, wmma_gemm_16x16x32_nn) {
|
||||
@ -367,7 +368,5 @@ TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tn) {
|
||||
run_gemm<WmmaGemmTraits>(256, 256, 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
|
||||
155
tools/test/unit/gemm/wmma_gemm_non_multiple16.cu
Normal file
155
tools/test/unit/gemm/wmma_gemm_non_multiple16.cu
Normal file
@ -0,0 +1,155 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#if defined(CUTLASS_USE_WMMA_API)
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
#include "tools/test/unit/gemm/gemm_testbed.h"
|
||||
#include "tools/test/unit/gemm/run_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_f16, wmma_gemm_36x36x16_nn) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_f16, wmma_gemm_36x36x16_nt) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_f16, wmma_gemm_36x36x16_tn) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(WmmaGemm_16x16x32_f16, wmma_gemm_36x36x16_tt) {
|
||||
/*
|
||||
this wmmaTraits requires leading dim to be divisible by 4
|
||||
*/
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::Shape<32, 16, 16>,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
cutlass::gemm::LinearScaling<half>,
|
||||
half,
|
||||
typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
|
||||
typename cutlass::Shape<16, 16, 16>,
|
||||
4, /*kScalarsPerLdgA_*/
|
||||
4, /*kScalarsPerLdgB_*/
|
||||
4, /*KScalarsPerLdsA_*/
|
||||
4, /*KScalarsPerLdsB_*/
|
||||
4 / sizeof(half), /*kScalarsPerLdgCAndStgD_*/
|
||||
4 / sizeof(half), /*kScalarsPerStsD_*/
|
||||
4 / sizeof(half) /*kScalarsPerLdsD_*/
|
||||
>
|
||||
WmmaGemmTraits;
|
||||
|
||||
run_gemm<WmmaGemmTraits>(36, 36, 64);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#endif
|
||||
307
tools/test/unit/reduction/batched_reduction.cu
Normal file
307
tools/test/unit/reduction/batched_reduction.cu
Normal file
@ -0,0 +1,307 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "tools/util/host_tensor.h"
|
||||
#include "cutlass/reduction/batched_reduction.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/reduction/test_batched_reduction.h"
|
||||
#include "tools/test/unit/reduction/batched_reduction_testbed.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Batched_reduction_float, batched_reduction_128x256x16) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x16
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 16;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*A*/
|
||||
float, /*C*/
|
||||
float, /*D*/
|
||||
float, /*alpha and beta*/
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Batched_reduction_double, batched_reduction_128x256x16) {
|
||||
/*
|
||||
D = alpha * Reduction(A) + beta * C
|
||||
The output matrix D is 128x256
|
||||
The input matrix A is 128x256x16
|
||||
The input matrix C is 128x256
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 16;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Batched_reduction_half, batched_reduction_128x256x16) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x16
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 16;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Batched_reduction_float, batched_reduction_128x64x80) {
|
||||
/*
|
||||
The output matrix is 128x64
|
||||
The input matrix is 128x64x80
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 64;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 80;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_80;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_80>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Batched_reduction_double, batched_reduction_128x64x80) {
|
||||
/*
|
||||
The output matrix is 128x64
|
||||
The input matrix is 128x64x80
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 64;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 80;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_80;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_80>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Batched_reduction_half, batched_reduction_128x64x80) {
|
||||
/*
|
||||
The output matrix is 128x64
|
||||
The input matrix is 128x64x80
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 64;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 80;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half,
|
||||
half,
|
||||
half,
|
||||
half,
|
||||
half, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_80;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_80>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Batched_reduction_float_threadShape1, batched_reduction_128x256x90) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x90
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 90;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*A*/
|
||||
float, /*C*/
|
||||
float, /*D*/
|
||||
float, /*alpha and beta*/
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 1> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Batched_reduction_double_threadShape1, batched_reduction_128x256x90) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x90
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 90;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<double, /*A*/
|
||||
double, /*C*/
|
||||
double, /*D*/
|
||||
double, /*alpha and beta*/
|
||||
double, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 1> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
TEST(Batched_reduction_half_threadShape1, batched_reduction_128x256x90) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x90
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 90;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<half, /*A*/
|
||||
half, /*C*/
|
||||
half, /*D*/
|
||||
half, /*alpha and beta*/
|
||||
half, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 1> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
301
tools/test/unit/reduction/batched_reduction_testbed.h
Normal file
301
tools/test/unit/reduction/batched_reduction_testbed.h
Normal file
@ -0,0 +1,301 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Test environment for batched reduction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
#include "tools/util/host_matrix.h"
|
||||
#include "tools/util/host_matrix_view.h"
|
||||
#include "tools/util/host_tensor.h"
|
||||
#include "tools/util/tensor_view_io.h"
|
||||
#include "tools/util/type_traits.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
namespace test {
|
||||
|
||||
inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) {
|
||||
switch (layout) {
|
||||
case cutlass::MatrixLayout::kRowMajor:
|
||||
return CUBLAS_OP_T;
|
||||
case cutlass::MatrixLayout::kColumnMajor:
|
||||
return CUBLAS_OP_N;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return CUBLAS_OP_N;
|
||||
}
|
||||
|
||||
inline cutlass::MatrixLayout::Kind convert(cublasOperation_t transform) {
|
||||
switch (transform) {
|
||||
case CUBLAS_OP_T:
|
||||
return cutlass::MatrixLayout::kRowMajor;
|
||||
case CUBLAS_OP_N:
|
||||
return cutlass::MatrixLayout::kColumnMajor;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return cutlass::MatrixLayout::kColumnMajor;
|
||||
}
|
||||
|
||||
/// Testbed for evaluating batched reduction
|
||||
template <
|
||||
typename AType,
|
||||
typename CType,
|
||||
typename DType,
|
||||
typename ScalarAlpha,
|
||||
typename ScalarBeta,
|
||||
typename ScalarAccum,
|
||||
// input matrix depth size to be sumed
|
||||
int ReductionSize
|
||||
>
|
||||
struct BatchedReductionTestbed {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
/// Host tensor for operand C
|
||||
typedef cutlass::HostTensor<AType, 3> HostTensorA;
|
||||
|
||||
/// Host tensor for operand C
|
||||
typedef cutlass::HostMatrix<CType> HostMatrixC;
|
||||
|
||||
/// Host tensor for operand D
|
||||
typedef cutlass::HostMatrix<DType> HostMatrixD;
|
||||
|
||||
/// Generates random elements
|
||||
template <typename T>
|
||||
struct RandomGenerator {
|
||||
RandomGenerator(int seed = -1, bool only_ones_ = false) : only_ones(only_ones_) { srand(seed); }
|
||||
|
||||
T operator()() {
|
||||
if (only_ones) {
|
||||
return T(1);
|
||||
}
|
||||
else {
|
||||
int val = (rand() % 16) - 8;
|
||||
return T(val);
|
||||
}
|
||||
}
|
||||
|
||||
bool only_ones;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct RandomBitGenerator {
|
||||
RandomBitGenerator(int seed = -1) { srand(seed); }
|
||||
|
||||
T operator()() {
|
||||
uint32_t val = 0;
|
||||
for (int i = 0; i < 32; i++) {
|
||||
val |= rand() % 2;
|
||||
val <<= 1;
|
||||
}
|
||||
return T(val);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// input/output number of rows
|
||||
int m;
|
||||
|
||||
/// input/output number of columns
|
||||
int n;
|
||||
|
||||
/// A matrix operand, always column major, no trans
|
||||
HostTensorA A;
|
||||
|
||||
/// C matrix operand, always column major, no trans
|
||||
HostMatrixC C;
|
||||
|
||||
/// D matrix operand, always column major, no trans
|
||||
HostMatrixD D;
|
||||
|
||||
/// Reference
|
||||
cutlass::HostTensor<AType, 3> ref_A;
|
||||
|
||||
///
|
||||
cutlass::HostMatrix<CType> ref_C;
|
||||
|
||||
/// Reference result computed on the host
|
||||
cutlass::HostMatrix<DType> ref_D;
|
||||
|
||||
/// lda
|
||||
int lda;
|
||||
|
||||
/// ldc
|
||||
int ldc;
|
||||
|
||||
/// ldd
|
||||
int ldd;
|
||||
|
||||
/// Linear scalaring factor
|
||||
ScalarAlpha alpha;
|
||||
|
||||
/// Linear scaling factor
|
||||
ScalarBeta beta;
|
||||
|
||||
/// stride between two element that will be sumed
|
||||
long long int reduction_stride;
|
||||
|
||||
//
|
||||
// Static helpers
|
||||
//
|
||||
|
||||
/// Helper to resize a matrix with a given size and layout
|
||||
template <typename T>
|
||||
static void resize(cutlass::HostMatrix<T>& tensor,
|
||||
int rows,
|
||||
int columns,
|
||||
cublasOperation_t layout,
|
||||
int ldm = 0,
|
||||
bool device_backed = true) {
|
||||
|
||||
tensor.resize(cutlass::make_Coord(rows, columns), convert(layout), ldm, device_backed);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void resize(cutlass::HostTensor<T, 3>& tensor,
|
||||
int rows,
|
||||
int columns,
|
||||
int batches,
|
||||
cublasOperation_t layout,
|
||||
int ldm,
|
||||
long long int batch_stride,
|
||||
bool device_backed = true) {
|
||||
assert(CUBLAS_OP_N == layout);
|
||||
//tensor.resize(cutlass::make_Coord(rows, columns), convert(layout), ldm, device_backed);
|
||||
tensor.reset(cutlass::make_Coord(static_cast<int>(batch_stride), ldm, 1), /*stride, slowest moving dim on the left*/
|
||||
cutlass::make_Coord(batches, columns, rows), /*size, slowest moving dim on the left*/
|
||||
device_backed);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
BatchedReductionTestbed(int m_,
|
||||
int n_,
|
||||
int lda_,
|
||||
int ldc_,
|
||||
int ldd_,
|
||||
typename cutlass::TypeTraits<ScalarAlpha>::host_type alpha_ =
|
||||
typename cutlass::TypeTraits<ScalarAlpha>::host_type(2),
|
||||
typename cutlass::TypeTraits<ScalarAlpha>::host_type beta_ =
|
||||
typename cutlass::TypeTraits<ScalarAlpha>::host_type(3))
|
||||
: m(m_),
|
||||
n(n_),
|
||||
lda(lda_),
|
||||
ldc(ldc_),
|
||||
ldd(ldd_),
|
||||
alpha(alpha_),
|
||||
beta(beta_),
|
||||
reduction_stride(ldc_ * n_) {
|
||||
/// column major, batch along rows
|
||||
resize(A, m_, n_, ReductionSize, CUBLAS_OP_N, lda_, reduction_stride, true);
|
||||
resize(C, m_, n_, CUBLAS_OP_N, ldc_, true);
|
||||
resize(D, m_, n_, CUBLAS_OP_N, ldd_, true);
|
||||
resize(ref_A, m_, n_, ReductionSize, CUBLAS_OP_N, lda_, reduction_stride, false);
|
||||
resize(ref_C, m_, n_, CUBLAS_OP_N, ldc_, false);
|
||||
resize(ref_D, m_, n_, CUBLAS_OP_N, ldd_, false);
|
||||
}
|
||||
|
||||
/// Dtor
|
||||
~BatchedReductionTestbed() { }
|
||||
|
||||
/// Getters
|
||||
/// Returns a pointer to the C operand
|
||||
typename HostTensorA::DeviceType* ptr_A() const { return A.device_data(); }
|
||||
/// Returns a pointer to the C operand
|
||||
typename HostMatrixC::DeviceType* ptr_C() const { return C.device_data(); }
|
||||
/// Returns a pointer to the D operand
|
||||
typename HostMatrixD::DeviceType* ptr_D() const { return D.device_data(); }
|
||||
|
||||
///
|
||||
int M() const { return m; }
|
||||
///
|
||||
int N() const { return n; }
|
||||
///
|
||||
int get_lda() const { return lda; }
|
||||
///
|
||||
int get_ldc() const { return ldc; }
|
||||
///
|
||||
int get_ldd() const { return ldd; }
|
||||
///
|
||||
ScalarAlpha get_alpha() const { return alpha; }
|
||||
///
|
||||
ScalarBeta get_beta() const { return beta; }
|
||||
///
|
||||
long long int get_reduction_stride() const { return reduction_stride; }
|
||||
|
||||
/// Initializes data, randomly
|
||||
void initialize(int seed = -1) {
|
||||
A.fill_random(RandomGenerator<AType>(seed + 7));
|
||||
//A.fill(3);
|
||||
C.fill_random(RandomGenerator<CType>(seed));
|
||||
//C.fill(1);
|
||||
D.fill_random(RandomGenerator<DType>(seed + 11));
|
||||
//D.fill(2);
|
||||
}
|
||||
|
||||
/// compute_host
|
||||
void compute_host() {
|
||||
ref_A.fill(A);
|
||||
ref_C.fill(C);
|
||||
ref_D.fill(D);
|
||||
/// D = alpha * reduction(A) + beta * C
|
||||
|
||||
for (int m_idx = 0; m_idx < m; m_idx++) {
|
||||
for (int n_idx = 0; n_idx < n; n_idx++) {
|
||||
ScalarAccum accum = static_cast<ScalarAccum>(0.0);
|
||||
for (int r_idx = 0; r_idx < static_cast<int>(ReductionSize); r_idx++) {
|
||||
accum += static_cast<ScalarAccum>(ref_A.at(cutlass::make_Coord(r_idx, n_idx, m_idx)));
|
||||
}
|
||||
ref_D.at(cutlass::make_Coord(m_idx, n_idx)) = static_cast<DType>(
|
||||
alpha * static_cast<ScalarAlpha>(accum) +
|
||||
beta * static_cast<ScalarBeta>(ref_C.at(cutlass::make_Coord(m_idx, n_idx)))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies the contents of C equal the host-side reference
|
||||
bool verify_with_host() {
|
||||
compute_host();
|
||||
D.sync_host();
|
||||
bool passed = D.bit_equals(ref_D);
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
|
||||
} //namespace test
|
||||
161
tools/test/unit/reduction/mixed_batched_reduction.cu
Normal file
161
tools/test/unit/reduction/mixed_batched_reduction.cu
Normal file
@ -0,0 +1,161 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "tools/util/host_tensor.h"
|
||||
#include "cutlass/reduction/batched_reduction.h"
|
||||
#include "cutlass/reduction/batched_reduction_traits.h"
|
||||
#include "tools/test/unit/reduction/test_batched_reduction.h"
|
||||
#include "tools/test/unit/reduction/batched_reduction_testbed.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Float_batched_reduction_half_alphabeta_float, batched_reduction_128x256x16) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x16
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
A is float, Accumulation is float
|
||||
alpha and beta are float
|
||||
C and D are half
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 16;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*A*/
|
||||
half, /*C*/
|
||||
half, /*D*/
|
||||
float, /*alpha and beta*/
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Float_batched_reduction_half_alphabeta_half, batched_reduction_128x256x16) {
|
||||
/*
|
||||
The output matrix is 128x256
|
||||
The input matrix is 128x256x16
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
A is float, Accumulation is float
|
||||
alpha and beta are float
|
||||
C and D are half
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 256;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 16;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*A*/
|
||||
half, /*C*/
|
||||
half, /*D*/
|
||||
half, /*alpha and beta*/
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_16;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_16>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Float_batched_reduction_half_alphabeta_float, batched_reduction_128x64x80) {
|
||||
/*
|
||||
The output matrix is 128x64
|
||||
The input matrix is 128x64x80
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 64;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 80;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*A*/
|
||||
half, /*C*/
|
||||
half, /*D*/
|
||||
float, /*alpha and beta*/
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_80;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_80>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Float_batched_reduction_half_alphabeta_half, batched_reduction_128x64x80) {
|
||||
/*
|
||||
The output matrix is 128x64
|
||||
The input matrix is 128x64x80
|
||||
The reduction will be applied at the third dim of input matrix
|
||||
*/
|
||||
|
||||
|
||||
const int m = 128;
|
||||
const int n = 64;
|
||||
const int lda = 128;
|
||||
const int ldc = 128;
|
||||
const int ldd = 128;
|
||||
const int reduction_size = 80;
|
||||
typedef cutlass::reduction::BatchedReductionTraits<float, /*A*/
|
||||
half, /*C*/
|
||||
half, /*D*/
|
||||
half, /*alpha and beta*/
|
||||
float, /*accumulation type*/
|
||||
reduction_size,
|
||||
cutlass::Shape<1, 1, 128>,
|
||||
cutlass::Shape<1, 1, 64>,
|
||||
cutlass::Shape<1, 1, 2> >
|
||||
BatchedReductionTraits_80;
|
||||
|
||||
test_batched_reduction<BatchedReductionTraits_80>(m, n, lda, ldc, ldd);
|
||||
|
||||
}
|
||||
73
tools/test/unit/reduction/test_batched_reduction.h
Normal file
73
tools/test/unit/reduction/test_batched_reduction.h
Normal file
@ -0,0 +1,73 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Test environment for batched reduction
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "tools/test/unit/reduction/batched_reduction_testbed.h"
|
||||
|
||||
template <typename BatchedRecutionTraits_>
|
||||
static void test_batched_reduction(
|
||||
int m,
|
||||
int n,
|
||||
int lda,
|
||||
int ldc,
|
||||
int ldd) {
|
||||
typedef BatchedRecutionTraits_ Traits;
|
||||
typedef cutlass::reduction::BatchedReduction<Traits> batched_reduction;
|
||||
typename batched_reduction::Params params;
|
||||
|
||||
test::BatchedReductionTestbed<typename cutlass::TypeTraits<typename Traits::ScalarA>::host_type,
|
||||
typename cutlass::TypeTraits<typename Traits::ScalarC>::host_type,
|
||||
typename cutlass::TypeTraits<typename Traits::ScalarD>::host_type,
|
||||
typename cutlass::TypeTraits<typename Traits::ScalarAlphaBeta>::host_type,
|
||||
typename cutlass::TypeTraits<typename Traits::ScalarAlphaBeta>::host_type,
|
||||
typename cutlass::TypeTraits<typename Traits::ScalarAccum>::host_type,
|
||||
Traits::ReductionSize>
|
||||
testbed(m, n, lda, ldc, ldd);
|
||||
testbed.initialize();
|
||||
|
||||
params.initialize(testbed.M(),
|
||||
testbed.N(),
|
||||
testbed.get_alpha(),
|
||||
testbed.get_beta(),
|
||||
testbed.get_reduction_stride(),
|
||||
testbed.ptr_A(),
|
||||
testbed.get_lda(),
|
||||
testbed.ptr_C(),
|
||||
testbed.get_ldc(),
|
||||
testbed.ptr_D(),
|
||||
testbed.get_ldd());
|
||||
|
||||
|
||||
batched_reduction::launch(params);
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
ASSERT_EQ(result, cudaSuccess) << "\nCUDA kernel launch error: " << cudaGetErrorString(result)
|
||||
<< "\n";
|
||||
|
||||
ASSERT_TRUE(testbed.verify_with_host());
|
||||
}
|
||||
125
tools/test/unit/tile_iterator_test.cu
Normal file
125
tools/test/unit/tile_iterator_test.cu
Normal file
@ -0,0 +1,125 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using ::cutlass::Coord;
|
||||
using ::cutlass::Fragment;
|
||||
using ::cutlass::IteratorAdvance;
|
||||
using ::cutlass::make_Coord;
|
||||
using ::cutlass::MemorySpace;
|
||||
using ::cutlass::Shape;
|
||||
using ::cutlass::TileLoadIterator;
|
||||
using ::cutlass::TileTraits;
|
||||
using ::testing::Test;
|
||||
|
||||
|
||||
// TODO: Move the following to standard test helper infrastructure
|
||||
// Returns randomly initialized array
|
||||
//
|
||||
// Caller is responsible for deallocation.
|
||||
float* malloc_randomly_initialized_array(int elements) {
|
||||
float* matrix = (float*)calloc(sizeof(float), elements);
|
||||
for (int i = 0; i < elements; i++) {
|
||||
matrix[i] = float((rand() - RAND_MAX/2) % 10);
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
|
||||
#define kWarpSize 32
|
||||
#define kCtaWarpCnt 6
|
||||
#define kDimXPerWarp 16
|
||||
#define kDimYPerWarp 2
|
||||
#define kWarpTileWidth kDimXPerWarp
|
||||
#define kDimYPerThread (kWarpSize / kDimYPerWarp)
|
||||
#define kDimX 2400
|
||||
#define kDimY 800
|
||||
|
||||
struct TileThreadOffset {
|
||||
public:
|
||||
TileThreadOffset() : xidx(0), yidx(0) {}
|
||||
TileThreadOffset(int x, int y) : xidx(x), yidx(y) {}
|
||||
|
||||
__host__ __device__ Coord<4> operator()() const {
|
||||
int column = (yidx / kDimYPerWarp) * kDimXPerWarp +
|
||||
(yidx & (kDimYPerWarp - 1)) * kDimYPerThread;
|
||||
return make_Coord(0, column, xidx, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
int xidx, yidx;
|
||||
};
|
||||
|
||||
|
||||
TEST(TileIteratorTest, BasicCpuSideIterateTile) {
|
||||
// Basic test demonstrating CPU-side tile iteration mimicking a 16x16 tile load/warp with 6 warp
|
||||
// CTAs iterating over the Y.
|
||||
|
||||
float* matrix = malloc_randomly_initialized_array(kDimX*kDimY);
|
||||
|
||||
typedef Shape</*kD=*/1, /*kH=*/kCtaWarpCnt * kDimXPerWarp, /*kW=*/kDimXPerWarp> TileShape;
|
||||
typedef TileLoadIterator<
|
||||
TileTraits<TileShape,
|
||||
/* Delta = */ Shape</*kD=*/1, /*kH=*/1, /*kW=*/1>,
|
||||
/* Iter = */ Shape</*kD=*/1, /*kH=*/kDimYPerThread, /*kW=*/1>,
|
||||
TileThreadOffset, /*AccessSize=*/1>,
|
||||
float, IteratorAdvance::kH, MemorySpace::kGlobal> GlobalTileLoader;
|
||||
typedef GlobalTileLoader::Fragment BufferType;
|
||||
//
|
||||
// TODO: The following loop should probably be refactored out into standard test helper code for
|
||||
// tile iteration.
|
||||
//
|
||||
// Iterate: gridDim(1, 1, kDimX / kDimXPerWarp), blockDim(1, kDimXPerWarp, kDimYPerWarp)
|
||||
for (int blockIdx_x = 0; blockIdx_x < kDimX / kDimXPerWarp; blockIdx_x++) {
|
||||
for (int threadIdx_x = 0; threadIdx_x < kDimXPerWarp; threadIdx_x++) {
|
||||
for (int threadIdx_y = 0; threadIdx_y < kCtaWarpCnt * kDimYPerWarp; threadIdx_y++) {
|
||||
GlobalTileLoader loader(
|
||||
GlobalTileLoader::Params(matrix,
|
||||
/* stride_d=*/1, /*stride_h=*/kDimX, /*stride_w=*/1),
|
||||
make_Coord(/*d=*/0, /*h=*/0, /*w=*/blockIdx_x * kDimXPerWarp),
|
||||
TileThreadOffset(threadIdx_x, threadIdx_y));
|
||||
BufferType b;
|
||||
for (int yidx = 0; (yidx + threadIdx_y * kWarpTileWidth) < kDimY;
|
||||
yidx += kCtaWarpCnt*kWarpTileWidth) {
|
||||
|
||||
loader.load_post_increment(b);
|
||||
for (int i = 0; i < BufferType::kElements; i++) {
|
||||
int matrix_idx = blockIdx_x * kDimXPerWarp + threadIdx_x + // row offset
|
||||
kDimX * ((threadIdx_y & (kDimYPerWarp - 1)) * kDimYPerThread +
|
||||
(threadIdx_y / kDimYPerWarp) * kWarpTileWidth + i + yidx);
|
||||
ASSERT_EQ(b[i], matrix[matrix_idx])
|
||||
<< "blockIdx.x = " << blockIdx_x << " threadIdx.x = " << threadIdx_x
|
||||
<< " threadIdx.y = " << threadIdx_y << " yidx = " << yidx
|
||||
<< " tile_idx = " << i << " matrix_idx = " << matrix_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
free(matrix);
|
||||
}
|
||||
127
tools/util/reference/detail/inner_product.h
Normal file
127
tools/util/reference/detail/inner_product.h
Normal file
@ -0,0 +1,127 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in host-side code.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace detail {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template function to compute an inner product.
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a
|
||||
// host-only type
|
||||
template <typename Atype, typename Btype, typename Ctype>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Ctype inner_product(Atype a, Btype b, Ctype c) {
|
||||
return Ctype(a) * Ctype(b) + c;
|
||||
}
|
||||
|
||||
/// Specialization for matrix multiplication with binary operands
|
||||
template <>
|
||||
CUTLASS_HOST_DEVICE
|
||||
int inner_product<Vector<bin1_t, 32>, Vector<bin1_t, 32>, int>(
|
||||
Vector<bin1_t, 32> a,
|
||||
Vector<bin1_t, 32> b,
|
||||
int c) {
|
||||
|
||||
int accum = 0;
|
||||
for (int bit = 0; bit < 32; bit++) {
|
||||
accum += a[bit] ^ b[bit];
|
||||
}
|
||||
return accum + c;
|
||||
}
|
||||
|
||||
/// Specialization for matrix multiplication with signed 4-bit integer operands
|
||||
template <>
|
||||
CUTLASS_HOST_DEVICE
|
||||
int inner_product<Vector<int4_t, 8>, Vector<int4_t, 8>, int>(
|
||||
Vector<int4_t, 8> a,
|
||||
Vector<int4_t, 8> b,
|
||||
int c) {
|
||||
|
||||
int accum = 0;
|
||||
for (int k = 0; k < 8; k++) {
|
||||
accum += a[k] * b[k];
|
||||
}
|
||||
return accum + c;
|
||||
}
|
||||
|
||||
/// Specialization for matrix multiplication with unsigned 4-bit integer operands
|
||||
template <>
|
||||
CUTLASS_HOST_DEVICE
|
||||
int inner_product<Vector<uint4_t, 8>, Vector<uint4_t, 8>, int>(
|
||||
Vector<uint4_t, 8> a,
|
||||
Vector<uint4_t, 8> b,
|
||||
int c) {
|
||||
|
||||
int accum = 0;
|
||||
for (int k = 0; k < 8; k++) {
|
||||
accum += a[k] * b[k];
|
||||
}
|
||||
return accum + c;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename SrcType, typename DstType>
|
||||
struct Cast {
|
||||
// Default behavior: convert to the destination type
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
static DstType apply(SrcType src) { return static_cast<DstType>(src); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Cast<float, int8_t> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int8_t apply(float src) {
|
||||
// Clamp to the range of signed 8-bit integers.
|
||||
return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Cast<float, uint8_t> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static uint8_t apply(float src) {
|
||||
// Clamp to the range of signed 8-bit integers.
|
||||
return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace detail
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
|
||||
224
tools/util/reference/device/gemm.h
Normal file
224
tools/util/reference/device/gemm.h
Normal file
@ -0,0 +1,224 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in device-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
#include "tools/util/reference/device/kernel/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
||||
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
||||
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
||||
/// arguments explicitly.
|
||||
template <
|
||||
typename TensorRefA,
|
||||
typename TensorRefB,
|
||||
typename TensorRefC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType
|
||||
>
|
||||
void Gemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefC tensor_c,
|
||||
AccumulatorType initial_accum) {
|
||||
|
||||
typedef typename TensorRefA::Storage AType;
|
||||
typedef typename TensorRefB::Storage BType;
|
||||
typedef typename TensorRefC::Storage CType;
|
||||
|
||||
static_assert(
|
||||
TensorRefA::kRank == 2 &&
|
||||
TensorRefB::kRank == 2 &&
|
||||
TensorRefC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
// Blocking structure potentially improves performance of reference implementation
|
||||
// with a minor increase in complexity.
|
||||
//
|
||||
// Note, this reference implementation is NOT expected to approach peak performance.
|
||||
typedef Shape<1, 4, 4> OutputTile;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * OutputTile::kW - 1) / (block.x * OutputTile::kW),
|
||||
(problem_size.n() + block.y * OutputTile::kH - 1) / (block.y * OutputTile::kH)
|
||||
);
|
||||
|
||||
// Launch a GEMM kernel
|
||||
kernel::Gemm<
|
||||
TensorRefA,
|
||||
TensorRefB,
|
||||
TensorRefC,
|
||||
ScalarType,
|
||||
AccumulatorType,
|
||||
OutputTile
|
||||
><<< grid, block >>>(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
initial_accum
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
/// This assumes the accumulator type is the same type as the scalars.
|
||||
template <
|
||||
typename TensorRefA,
|
||||
typename TensorRefB,
|
||||
typename TensorRefC,
|
||||
typename ScalarType
|
||||
>
|
||||
void Gemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefC tensor_c) {
|
||||
|
||||
Gemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Batched GEMM
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a batch of GEMMs over a set of matrices of common dimension.
|
||||
//
|
||||
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
||||
//
|
||||
template <
|
||||
typename TensorRefCollectionA,
|
||||
typename TensorRefCollectionB,
|
||||
typename TensorRefCollectionC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType
|
||||
>
|
||||
void BatchedGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefCollectionA tensor_a,
|
||||
TensorRefCollectionB tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefCollectionC tensor_c,
|
||||
AccumulatorType initial_accum) {
|
||||
|
||||
typedef typename TensorRefCollectionA::Storage AType;
|
||||
typedef typename TensorRefCollectionB::Storage BType;
|
||||
typedef typename TensorRefCollectionC::Storage CType;
|
||||
|
||||
static_assert(
|
||||
TensorRefCollectionA::kRank == 2 &&
|
||||
TensorRefCollectionB::kRank == 2 &&
|
||||
TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2");
|
||||
|
||||
// Blocking structure potentially improves performance of reference implementation
|
||||
// with a minor increase in complexity.
|
||||
//
|
||||
// Note, this reference implementation is NOT expected to approach peak performance.
|
||||
typedef Shape<1, 4, 4> OutputTile;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * OutputTile::kW - 1) / (block.x * OutputTile::kW),
|
||||
(problem_size.n() + block.y * OutputTile::kH - 1) / (block.y * OutputTile::kH),
|
||||
problem_size.batch()
|
||||
);
|
||||
|
||||
// Launch a GEMM kernel
|
||||
kernel::BatchedGemm<
|
||||
TensorRefCollectionA,
|
||||
TensorRefCollectionB,
|
||||
TensorRefCollectionC,
|
||||
ScalarType,
|
||||
AccumulatorType,
|
||||
OutputTile
|
||||
><<< grid, block >>>(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
initial_accum
|
||||
);
|
||||
}
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
//
|
||||
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
||||
//
|
||||
template <
|
||||
typename TensorRefCollectionA,
|
||||
typename TensorRefCollectionB,
|
||||
typename TensorRefCollectionC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType
|
||||
>
|
||||
void BatchedGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefCollectionA tensor_a,
|
||||
TensorRefCollectionB tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefCollectionC tensor_c) {
|
||||
|
||||
BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
148
tools/util/reference/device/kernel/gemm.h
Normal file
148
tools/util/reference/device/kernel/gemm.h
Normal file
@ -0,0 +1,148 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
#include "tools/util/reference/device/thread/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename TensorRefA,
|
||||
typename TensorRefB,
|
||||
typename TensorRefC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType,
|
||||
typename OutputTile
|
||||
>
|
||||
__global__ void Gemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefC tensor_c,
|
||||
AccumulatorType initial_accum) {
|
||||
|
||||
// Map each thread to a unique tile of the output matrix
|
||||
MatrixCoord output_coord(
|
||||
(threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kW,
|
||||
(threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kH
|
||||
);
|
||||
|
||||
// Compute the general matrix product
|
||||
thread::Gemm<
|
||||
TensorRefA,
|
||||
TensorRefB,
|
||||
TensorRefC,
|
||||
ScalarType,
|
||||
AccumulatorType,
|
||||
OutputTile
|
||||
> gemm(initial_accum);
|
||||
|
||||
gemm.multiply_add(
|
||||
problem_size,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
output_coord);
|
||||
|
||||
gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename TensorRefCollectionA,
|
||||
typename TensorRefCollectionB,
|
||||
typename TensorRefCollectionC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType,
|
||||
typename OutputTile
|
||||
>
|
||||
__global__ void BatchedGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefCollectionA tensor_collection_a,
|
||||
TensorRefCollectionB tensor_collection_b,
|
||||
ScalarType beta,
|
||||
TensorRefCollectionC tensor_collection_c,
|
||||
AccumulatorType initial_accum) {
|
||||
|
||||
// Obtain batch ID
|
||||
int batch_id = blockIdx.z;
|
||||
|
||||
// Dereference based on batch_id
|
||||
typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
|
||||
typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
|
||||
typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);
|
||||
|
||||
// Map each thread to a unique tile of the output matrix
|
||||
MatrixCoord output_coord(
|
||||
(threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kW,
|
||||
(threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kH
|
||||
);
|
||||
|
||||
// Compute the general matrix product
|
||||
thread::Gemm<
|
||||
typename TensorRefCollectionA::TensorRef,
|
||||
typename TensorRefCollectionB::TensorRef,
|
||||
typename TensorRefCollectionC::TensorRef,
|
||||
ScalarType,
|
||||
AccumulatorType,
|
||||
OutputTile
|
||||
> gemm(initial_accum);
|
||||
|
||||
gemm.multiply_add(
|
||||
problem_size,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
output_coord);
|
||||
|
||||
gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
95
tools/util/reference/device/kernel/split_complex_gemm.h
Normal file
95
tools/util/reference/device/kernel/split_complex_gemm.h
Normal file
@ -0,0 +1,95 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
#include "cutlass/util/complex.h"
|
||||
|
||||
#include "tools/util/reference/device/thread/split_complex_gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
template <
|
||||
typename TensorRefA, /// concept: ZipTensorRef
|
||||
typename TensorRefB, /// concept: ZipTensorRef
|
||||
typename TensorRefC, /// concept: ZipTensorRef
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType, /// real-valued type underlying complex accumulators
|
||||
typename OutputTile /// concept: Shape
|
||||
>
|
||||
__global__ void SplitComplexGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
platform::complex<ScalarType> alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
platform::complex<ScalarType> beta,
|
||||
TensorRefC tensor_c,
|
||||
platform::complex<AccumulatorType> initial_accum) {
|
||||
|
||||
// Map each thread to a unique tile of the output matrix
|
||||
MatrixCoord output_coord(
|
||||
(threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kW,
|
||||
(threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kH
|
||||
);
|
||||
|
||||
// Compute the general matrix product
|
||||
thread::Gemm<
|
||||
TensorRefA,
|
||||
TensorRefB,
|
||||
TensorRefC,
|
||||
ScalarType,
|
||||
AccumulatorType,
|
||||
OutputTile
|
||||
> gemm(initial_accum);
|
||||
|
||||
gemm.multiply_add(
|
||||
problem_size,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
output_coord);
|
||||
|
||||
gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
103
tools/util/reference/device/split_complex_gemm.h
Normal file
103
tools/util/reference/device/split_complex_gemm.h
Normal file
@ -0,0 +1,103 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in device-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
#include "cutlass/util/complex.h"
|
||||
|
||||
#include "tools/util/reference/device/kernel/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued GEMM whose operands are in the split-complex format.
|
||||
template <
|
||||
typename TensorRefA, /// concept: ZipTensorRef
|
||||
typename TensorRefB, /// concept: ZipTensorRef
|
||||
typename TensorRefC, /// concept: ZipTensorRef
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType /// real-valued type underlying complex accumulators
|
||||
>
|
||||
void SplitComplexGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
platform::complex<ScalarType> alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
platform::complex<ScalarType> beta,
|
||||
TensorRefC tensor_c,
|
||||
platform::complex<ScalarType> initial_accum) {
|
||||
|
||||
static_assert(
|
||||
TensorRefA::First::kRank == 2 && TensorRefA::Second::kRank == 2 &&
|
||||
TensorRefB::First::kRank == 2 && TensorRefB::Second::kRank == 2 &&
|
||||
TensorRefC::First::kRank == 2 && TensorRefC::Second::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
// Blocking structure potentially improves performance of reference implementation
|
||||
// with a minor increase in complexity.
|
||||
//
|
||||
// Note, this reference implementation is NOT expected to approach peak performance.
|
||||
typedef Shape<1, 4, 4> OutputTile;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * OutputTile::kW - 1) / (block.x * OutputTile::kW),
|
||||
(problem_size.n() + block.y * OutputTile::kH - 1) / (block.y * OutputTile::kH)
|
||||
);
|
||||
|
||||
// Launch a GEMM kernel
|
||||
kernel::SplitComplexGemm<
|
||||
TensorRefA,
|
||||
TensorRefB,
|
||||
TensorRefC,
|
||||
ScalarType,
|
||||
AccumulatorType,
|
||||
OutputTile
|
||||
><<< grid, block >>>(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
initial_accum
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
176
tools/util/reference/device/thread/gemm.h
Normal file
176
tools/util/reference/device/thread/gemm.h
Normal file
@ -0,0 +1,176 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
#include "tools/util/reference/detail/inner_product.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
namespace thread {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Thread-level blocked general matrix product.
|
||||
//
|
||||
// Note, this is a reference implementation. Performance is not expected to approach peak.
|
||||
//
|
||||
template <
|
||||
typename TensorRefA,
|
||||
typename TensorRefB,
|
||||
typename TensorRefC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType,
|
||||
typename OutputTile
|
||||
>
|
||||
struct Gemm {
|
||||
|
||||
typedef typename TensorRefA::Storage ScalarA;
|
||||
typedef typename TensorRefB::Storage ScalarB;
|
||||
typedef typename TensorRefC::Storage ScalarC;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Tile for A operand
|
||||
ScalarA A_tile[OutputTile::kW];
|
||||
|
||||
/// Tile for B operand
|
||||
ScalarB B_tile[OutputTile::kH];
|
||||
|
||||
/// Tile for Accumulator
|
||||
AccumulatorType accum[OutputTile::kH][OutputTile::kW];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {
|
||||
|
||||
// Clear fetch registers
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
A_tile[i] = ScalarA(0);
|
||||
}
|
||||
|
||||
for (int j = 0; j < OutputTile::kW; ++j) {
|
||||
B_tile[j] = ScalarB(0);
|
||||
}
|
||||
|
||||
// Clear accumulators
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
accum[j][i] = initial_accum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes a matrix product
|
||||
CUTLASS_HOST_DEVICE
|
||||
Gemm & multiply_add(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
MatrixCoord output_coord = MatrixCoord()) {
|
||||
|
||||
// Loop over the GEMM K dimension
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int k = 0; k < problem_size.k(); ++k) {
|
||||
|
||||
// Fetch a slice of the A matrix
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
if (output_coord.row() + i < problem_size.m()) {
|
||||
A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k));
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a slice of the B matrix
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
if (output_coord.column() + j < problem_size.n()) {
|
||||
B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j));
|
||||
}
|
||||
}
|
||||
|
||||
// Compute an accumulated matrix product
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
accum[j][i] = detail::inner_product(A_tile[i], B_tile[j], accum[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Performs linear scaling of matrix product and updates output tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Gemm & epilogue(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
ScalarType beta,
|
||||
TensorRefC tensor_c,
|
||||
MatrixCoord output_coord = MatrixCoord()) {
|
||||
|
||||
// Update the output tensor
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
MatrixCoord coord = output_coord + MatrixCoord(i, j);
|
||||
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
|
||||
|
||||
tensor_c.at(coord) = detail::Cast<ScalarType, ScalarC>::apply(
|
||||
alpha * ScalarType(accum[j][i]) +
|
||||
beta * ScalarType(tensor_c.at(coord))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
192
tools/util/reference/device/thread/split_complex_gemm.h
Normal file
192
tools/util/reference/device/thread/split_complex_gemm.h
Normal file
@ -0,0 +1,192 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GEMM in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
#include "tools/util/reference/detail/inner_product.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace device {
|
||||
namespace thread {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Thread-level blocked general matrix product.
|
||||
//
|
||||
// Note, this is a reference implementation. Performance is not expected to approach peak.
|
||||
//
|
||||
template <
|
||||
typename TensorRefA, /// concept: ZipTensorRef
|
||||
typename TensorRefB, /// concept: ZipTensorRef
|
||||
typename TensorRefC, /// concept: ZipTensorRef
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType, /// real-valued type underlying complex accumulators
|
||||
typename OutputTile /// concept: Shape
|
||||
>
|
||||
struct SplitComplexGemm {
|
||||
|
||||
typedef typename TensorRefA::First::Storage RealScalarA;
|
||||
typedef typename TensorRefB::First::Storage RealScalarB;
|
||||
typedef typename TensorRefC::First::Storage RealScalarC;
|
||||
|
||||
typedef platform::complex<RealScalarA> ScalarA;
|
||||
typedef platform::complex<RealScalarB> ScalarB;
|
||||
typedef platform::complex<AccumulatorType> ComplexAccumulator;
|
||||
typedef platform::complex<ScalarType> ComplexScalar;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Tile for A operand
|
||||
ScalarA A_tile[OutputTile::kW];
|
||||
|
||||
/// Tile for B operand
|
||||
ScalarB B_tile[OutputTile::kH];
|
||||
|
||||
/// Tile for Accumulator
|
||||
ComplexAccumulator accum[OutputTile::kH][OutputTile::kW];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Gemm(ComplexAccumulator initial_accum = AccumulatorType(0)) {
|
||||
|
||||
// Clear fetch registers
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
A_tile[i] = ScalarA(0);
|
||||
}
|
||||
|
||||
for (int j = 0; j < OutputTile::kW; ++j) {
|
||||
B_tile[j] = ScalarB(0);
|
||||
}
|
||||
|
||||
// Clear accumulators
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
accum[j][i] = initial_accum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes a matrix product
|
||||
CUTLASS_HOST_DEVICE
|
||||
Gemm & multiply_add(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
MatrixCoord output_coord = MatrixCoord()) {
|
||||
|
||||
// Loop over the GEMM K dimension
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (int k = 0; k < problem_size.k(); ++k) {
|
||||
|
||||
// Fetch a slice of the A matrix - zip into complex values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
if (output_coord.row() + i < problem_size.m()) {
|
||||
MatrixCoord coord(output_coord.row() + i, k);
|
||||
A_tile[i].real() = tensor_a.first.at(coord);
|
||||
A_tile[i].imag() = tensor_a.second.at(coord);
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a slice of the B matrix - zip into complex values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
if (output_coord.column() + j < problem_size.n()) {
|
||||
MatrixCoord coord(k, output_coord.column() + j);
|
||||
B_tile[j].real() = tensor_b.first.at(coord);
|
||||
B_tile[j].imag() = tensor_b.second.at(coord);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute an accumulated matrix product on complex values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
accum[j][i] = detail::inner_product(A_tile[i], B_tile[j], accum[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Performs linear scaling of matrix product and updates output tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Gemm & epilogue(
|
||||
gemm::GemmCoord problem_size,
|
||||
ComplexScalar alpha,
|
||||
ComplexScalar beta,
|
||||
TensorRefC tensor_c,
|
||||
MatrixCoord output_coord = MatrixCoord()) {
|
||||
|
||||
// Update the output tensor
|
||||
for (int j = 0; j < OutputTile::kH; ++j) {
|
||||
for (int i = 0; i < OutputTile::kW; ++i) {
|
||||
MatrixCoord coord = output_coord + MatrixCoord(i, j);
|
||||
if (coord < problem_size.mn()) {
|
||||
|
||||
ComplexScalar source(
|
||||
tensor_c.first.at(coord),
|
||||
tensor_c.second.at(coord)
|
||||
);
|
||||
|
||||
// Final calculation is performed in data type of scalars
|
||||
ComplexScalar result = alpha * ComplexScalar(accum[j][i].real(), accum[j][i].imag()) + beta * source;
|
||||
|
||||
// Unzip and convert into output tensor data type
|
||||
tensor_c.first.at(coord) = detail::Cast<ScalarType, RealScalarC>::apply(result.real());
|
||||
tensor_c.second.at(coord) = detail::Cast<ScalarType, RealScalarC>::apply(result.imag());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace device
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
@ -33,90 +33,14 @@
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
#include "tools/util/reference/detail/inner_product.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Template function to compute an inner product.
|
||||
template <typename Atype, typename Btype, typename Ctype>
|
||||
Ctype inner_product(Atype a, Btype b, Ctype c) {
|
||||
return Ctype(a) * Ctype(b) + c;
|
||||
}
|
||||
|
||||
/// Specialization for matrix multiplication with binary operands
|
||||
template <>
|
||||
inline int inner_product<Vector<bin1_t, 32>, Vector<bin1_t, 32>, int>(
|
||||
Vector<bin1_t, 32> a,
|
||||
Vector<bin1_t, 32> b,
|
||||
int c) {
|
||||
|
||||
int accum = 0;
|
||||
for (int bit = 0; bit < 32; bit++) {
|
||||
accum += a[bit] ^ b[bit];
|
||||
}
|
||||
return accum + c;
|
||||
}
|
||||
|
||||
/// Specialization for matrix multiplication with signed 4-bit integer operands
|
||||
template <> inline
|
||||
int inner_product<Vector<int4_t, 8>, Vector<int4_t, 8>, int>(
|
||||
Vector<int4_t, 8> a,
|
||||
Vector<int4_t, 8> b,
|
||||
int c) {
|
||||
|
||||
int accum = 0;
|
||||
for (int k = 0; k < 8; k++) {
|
||||
accum += a[k] * b[k];
|
||||
}
|
||||
return accum + c;
|
||||
}
|
||||
|
||||
/// Specialization for matrix multiplication with unsigned 4-bit integer operands
|
||||
template <> inline
|
||||
int inner_product<Vector<uint4_t, 8>, Vector<uint4_t, 8>, int>(
|
||||
Vector<uint4_t, 8> a,
|
||||
Vector<uint4_t, 8> b,
|
||||
int c) {
|
||||
|
||||
int accum = 0;
|
||||
for (int k = 0; k < 8; k++) {
|
||||
accum += a[k] * b[k];
|
||||
}
|
||||
return accum + c;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename SrcType, typename DstType>
|
||||
struct Cast {
|
||||
// Default behavior: convert to the destination type
|
||||
static inline DstType apply(SrcType src) { return static_cast<DstType>(src); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Cast<float, int8_t> {
|
||||
static inline int8_t apply(float src) {
|
||||
// Clamp to the range of signed 8-bit integers.
|
||||
return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Cast<float, uint8_t> {
|
||||
static inline uint8_t apply(float src) {
|
||||
// Clamp to the range of signed 8-bit integers.
|
||||
return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
///
|
||||
@ -178,7 +102,7 @@ void Gemm(
|
||||
AType a = tensor_a.at(MatrixCoord(row, k_block));
|
||||
BType b = tensor_b.at(MatrixCoord(k_block, col));
|
||||
|
||||
accum[i][j] = detail::inner_product(a, b, accum[i][j]);
|
||||
accum[i][j] = cutlass::reference::detail::inner_product(a, b, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -192,7 +116,7 @@ void Gemm(
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
if (row < M && col < N) {
|
||||
|
||||
tensor_c.at(coord) = detail::Cast<ScalarType, CType>::apply(
|
||||
tensor_c.at(coord) = cutlass::reference::detail::Cast<ScalarType, CType>::apply(
|
||||
alpha * ScalarType(accum[i][j]) +
|
||||
beta * ScalarType(tensor_c.at(coord)));
|
||||
}
|
||||
@ -225,9 +149,16 @@ void Gemm(
|
||||
Gemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Batched GEMM
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a batch of GEMMs over a set of matrices of common dimension.
|
||||
//
|
||||
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
||||
//
|
||||
template <
|
||||
typename TensorRefCollectionA,
|
||||
typename TensorRefCollectionB,
|
||||
@ -235,14 +166,14 @@ template <
|
||||
typename ScalarType,
|
||||
typename AccumulatorType
|
||||
>
|
||||
void BatchGemm(
|
||||
void BatchedGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefCollectionA const& tensor_a,
|
||||
TensorRefCollectionB const& tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefCollectionC &tensor_c,
|
||||
AccumulatorType initial_accum = AccumulatorType(0)) {
|
||||
AccumulatorType initial_accum) {
|
||||
|
||||
typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
|
||||
typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
|
||||
@ -263,6 +194,29 @@ void BatchGemm(
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
||||
/// objects.
|
||||
//
|
||||
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
||||
//
|
||||
template <
|
||||
typename TensorRefCollectionA,
|
||||
typename TensorRefCollectionB,
|
||||
typename TensorRefCollectionC,
|
||||
typename ScalarType,
|
||||
typename AccumulatorType
|
||||
>
|
||||
void BatchedGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
ScalarType alpha,
|
||||
TensorRefCollectionA const& tensor_a,
|
||||
TensorRefCollectionB const& tensor_b,
|
||||
ScalarType beta,
|
||||
TensorRefCollectionC &tensor_c) {
|
||||
|
||||
BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
|
||||
254
tools/util/reference/host/split_complex_gemm.h
Normal file
254
tools/util/reference/host/split_complex_gemm.h
Normal file
@ -0,0 +1,254 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for split-complex GEMM in device-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
#include "cutlass/util/complex.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reference {
|
||||
namespace host {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued GEMM whose operands are in the split-complex format.
|
||||
template <
|
||||
typename TensorRefA, /// concept: ZipTensorRef
|
||||
typename TensorRefB, /// concept: ZipTensorRef
|
||||
typename TensorRefC, /// concept: ZipTensorRef
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType /// real-valued type underlying complex accumulators
|
||||
>
|
||||
void SplitComplexGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
platform::complex<ScalarType> alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
platform::complex<ScalarType> beta,
|
||||
TensorRefC tensor_c,
|
||||
platform::complex<AccumulatorType> initial_accum) {
|
||||
|
||||
typedef typename TensorRefA::First::Storage AType;
|
||||
typedef typename TensorRefB::First::Storage BType;
|
||||
typedef typename TensorRefC::First::Storage CType;
|
||||
|
||||
typedef platform::complex<AType> ComplexAType;
|
||||
typedef platform::complex<BType> ComplexBType;
|
||||
typedef platform::complex<CType> ComplexCType;
|
||||
typedef platform::complex<ScalarType> ComplexScalarType;
|
||||
typedef platform::complex<AccumulatorType> ComplexAccumulatorType;
|
||||
|
||||
static_assert(
|
||||
TensorRefA::First::kRank == 2 && TensorRefA::Second::kRank == 2 &&
|
||||
TensorRefB::First::kRank == 2 && TensorRefB::Second::kRank == 2 &&
|
||||
TensorRefC::First::kRank == 2 && TensorRefC::Second::kRank == 2,
|
||||
"Tensors must be of rank 2");
|
||||
|
||||
// Note: batch is ignored.
|
||||
int const M = problem_size.m();
|
||||
int const N = problem_size.n();
|
||||
int const K = problem_size.k();
|
||||
|
||||
// Blocking necessary to speedup reference implementation
|
||||
int const Mblock = 32;
|
||||
int const Nblock = 32;
|
||||
|
||||
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
||||
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
||||
|
||||
ComplexAccumulatorType accum[Mblock][Nblock];
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
accum[i][j] = initial_accum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < K; ++k_block) {
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
if (row < M && col < N) {
|
||||
|
||||
ComplexAType a(
|
||||
tensor_a.first.at(MatrixCoord(row, k_block)),
|
||||
tensor_a.second.at(MatrixCoord(row, k_block))
|
||||
);
|
||||
|
||||
ComplexBType b(
|
||||
tensor_b.first.at(MatrixCoord(k_block, col)),
|
||||
tensor_b.second.at(MatrixCoord(k_block, col))
|
||||
);
|
||||
|
||||
accum[i][j] = detail::inner_product(a, b, accum[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < Nblock; j++) {
|
||||
for (int i = 0; i < Mblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
if (row < M && col < N) {
|
||||
|
||||
ComplexScalarType product(
|
||||
detail::Cast<AccumulatorType, ScalarType>::apply(accum[i][j].real()),
|
||||
detail::Cast<AccumulatorType, ScalarType>::apply(accum[i][j].imag())
|
||||
);
|
||||
|
||||
ComplexScalarType source(
|
||||
detail::Cast<CType, ScalarType>::apply(tensor_c.first.at(coord)),
|
||||
detail::Cast<CType, ScalarType>::apply(tensor_c.second.at(coord))
|
||||
);
|
||||
|
||||
ComplexScalarType result = alpha * product + beta * source;
|
||||
|
||||
tensor_c.first.at(coord) = detail::Cast<ScalarType, CType>::apply(result.real());
|
||||
tensor_c.second.at(coord) = detail::Cast<ScalarType, CType>::apply(result.imag());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued GEMM whose operands are in the split-complex format.
|
||||
template <
|
||||
typename TensorRefA, /// concept: ZipTensorRef
|
||||
typename TensorRefB, /// concept: ZipTensorRef
|
||||
typename TensorRefC, /// concept: ZipTensorRef
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType /// real-valued type underlying complex accumulators
|
||||
>
|
||||
void SplitComplexGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
platform::complex<ScalarType> alpha,
|
||||
TensorRefA tensor_a,
|
||||
TensorRefB tensor_b,
|
||||
platform::complex<ScalarType> beta,
|
||||
TensorRefC tensor_c) {
|
||||
|
||||
return SplitComplexGemm(problem_size, alpha, tensor_a, tensor_b,beta, tensor_c, ScalarType(0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Batched Split-Complex GEMM
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued GEMM whose operands are in the split-complex format.
|
||||
template <
|
||||
typename TensorRefCollectionA, /// concept: Pair<TensorRefCollection, TensorRefCollection>
|
||||
typename TensorRefCollectionB, /// concept: Pair<TensorRefCollection, TensorRefCollection>
|
||||
typename TensorRefCollectionC, /// concept: Pair<TensorRefCollection, TensorRefCollection>
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType /// real-valued type underlying complex accumulators
|
||||
>
|
||||
void BatchedSplitComplexGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
platform::complex<ScalarType> alpha,
|
||||
TensorRefCollectionA tensor_a,
|
||||
TensorRefCollectionB tensor_b,
|
||||
platform::complex<ScalarType> beta,
|
||||
TensorRefCollectionC tensor_c,
|
||||
platform::complex<AccumulatorType> initial_accum) {
|
||||
|
||||
typename TensorRefCollectionA::ConstIterator tensor_a_real = tensor_a.first.begin();
|
||||
typename TensorRefCollectionA::ConstIterator tensor_a_imag = tensor_a.second.begin();
|
||||
|
||||
typename TensorRefCollectionB::ConstIterator tensor_b_real = tensor_b.first.begin();
|
||||
typename TensorRefCollectionB::ConstIterator tensor_b_imag = tensor_b.second.begin();
|
||||
|
||||
typename TensorRefCollectionC::ConstIterator tensor_c_real = tensor_c.first.begin();
|
||||
typename TensorRefCollectionC::ConstIterator tensor_c_imag = tensor_c.second.begin();
|
||||
|
||||
for (int batch = 0; batch < problem_size.batch(); ++batch) {
|
||||
|
||||
SplitComplexGemm(
|
||||
problem_size,
|
||||
alpha,
|
||||
make_ZipTensorRef(*tensor_a_real, *tensor_a_imag),
|
||||
make_ZipTensorRef(*tensor_b_real, *tensor_b_imag),
|
||||
beta,
|
||||
make_ZipTensorRef(*tensor_c_real, *tensor_c_imag),
|
||||
initial_accum);
|
||||
|
||||
++tensor_a_real;
|
||||
++tensor_a_imag;
|
||||
++tensor_b_real;
|
||||
++tensor_b_imag;
|
||||
++tensor_c_real;
|
||||
++tensor_c_imag;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Computes a complex-valued GEMM whose operands are in the split-complex format.
|
||||
template <
|
||||
typename TensorRefCollectionA, /// concept: pair<TensorRefCollection, TensorRefCollection>
|
||||
typename TensorRefCollectionB, /// concept: pair<TensorRefCollection, TensorRefCollection>
|
||||
typename TensorRefCollectionC, /// concept: pair<TensorRefCollection, TensorRefCollection>
|
||||
typename ScalarType, /// real-valued type underlying complex scalars
|
||||
typename AccumulatorType /// real-valued type underlying complex accumulators
|
||||
>
|
||||
void BatchedSplitComplexGemm(
|
||||
gemm::GemmCoord problem_size,
|
||||
platform::complex<ScalarType> alpha,
|
||||
TensorRefCollectionA tensor_a,
|
||||
TensorRefCollectionB tensor_b,
|
||||
platform::complex<ScalarType> beta,
|
||||
TensorRefCollectionC tensor_c) {
|
||||
|
||||
BatchedSplitComplexGemm(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
beta,
|
||||
tensor_c,
|
||||
platform::complex<ScalarType>(0, 0));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace host
|
||||
} // namespace reference
|
||||
} // namespace cutlass
|
||||
@ -45,6 +45,7 @@ struct TypeTraits {
|
||||
typedef T device_type;
|
||||
static inline T remove_negative_zero(T x) { return x; }
|
||||
static inline T to_print(T x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -56,6 +57,7 @@ struct TypeTraits<Vector<bin1_t, 32> > {
|
||||
typedef uint32_t unsigned_type;
|
||||
static inline uint32_t remove_negative_zero(uint32_t x) { return x; }
|
||||
static inline uint32_t to_print(uint32_t x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -67,6 +69,7 @@ struct TypeTraits< Vector<int4_t, 8> > {
|
||||
typedef uint32_t unsigned_type;
|
||||
static inline uint32_t remove_negative_zero(uint32_t x) { return x; }
|
||||
static inline uint32_t to_print(uint32_t x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -78,6 +81,7 @@ struct TypeTraits< Vector<uint4_t, 8> > {
|
||||
typedef uint32_t unsigned_type;
|
||||
static inline uint32_t remove_negative_zero(uint32_t x) { return x; }
|
||||
static inline uint32_t to_print(uint32_t x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -89,6 +93,7 @@ struct TypeTraits<int8_t> {
|
||||
typedef uint8_t unsigned_type;
|
||||
static inline int8_t remove_negative_zero(int8_t x) { return x; }
|
||||
static inline int to_print(int8_t x) { return (int)x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -100,6 +105,7 @@ struct TypeTraits<uint8_t> {
|
||||
typedef uint8_t unsigned_type;
|
||||
static inline uint8_t remove_negative_zero(uint8_t x) { return x; }
|
||||
static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -111,6 +117,7 @@ struct TypeTraits<int> {
|
||||
typedef uint32_t unsigned_type;
|
||||
static inline int32_t remove_negative_zero(int32_t x) { return x; }
|
||||
static inline int to_print(int x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -122,6 +129,7 @@ struct TypeTraits<unsigned> {
|
||||
typedef uint32_t unsigned_type;
|
||||
static inline uint32_t remove_negative_zero(uint32_t x) { return x; }
|
||||
static inline uint32_t to_print(uint32_t x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -140,6 +148,7 @@ struct TypeTraits<half> {
|
||||
return x;
|
||||
}
|
||||
static inline half to_print(half x) { return x; }
|
||||
static inline device_type to_device(half x) { return reinterpret_cast<device_type const &>(x); }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -151,6 +160,7 @@ struct TypeTraits<int64_t> {
|
||||
typedef uint64_t unsigned_type;
|
||||
static inline int64_t remove_negative_zero(int64_t x) { return x; }
|
||||
static inline int64_t to_print(int64_t x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -162,6 +172,7 @@ struct TypeTraits<uint64_t> {
|
||||
typedef uint64_t unsigned_type;
|
||||
static inline uint64_t remove_negative_zero(uint64_t x) { return x; }
|
||||
static inline uint64_t to_print(uint64_t x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -175,6 +186,7 @@ struct TypeTraits<cutlass::half_t> {
|
||||
return (x.raw() == 0x8000 ? half_t::bitcast(0) : x);
|
||||
}
|
||||
static inline half_t to_print(half_t x) { return x; }
|
||||
static inline device_type to_device(cutlass::half_t x) { return reinterpret_cast<device_type const &>(x); }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -186,6 +198,7 @@ struct TypeTraits<float> {
|
||||
typedef uint32_t unsigned_type;
|
||||
static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; }
|
||||
static inline float to_print(float x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -197,6 +210,7 @@ struct TypeTraits<double> {
|
||||
typedef uint64_t unsigned_type;
|
||||
static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; }
|
||||
static inline double to_print(double x) { return x; }
|
||||
static inline device_type to_device(host_type x) { return x; }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -212,6 +226,7 @@ struct TypeTraits<platform::complex<half> > {
|
||||
typedef platform::complex<half> device_type;
|
||||
typedef int16_t integer_type;
|
||||
typedef uint16_t unsigned_type;
|
||||
static inline device_type to_device(platform::complex<half> x) { return reinterpret_cast<device_type const &>(x); }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -228,6 +243,7 @@ struct TypeTraits<platform::complex<half_t> > {
|
||||
);
|
||||
}
|
||||
static inline platform::complex<half_t> to_print(platform::complex<half_t> x) { return x; }
|
||||
static inline device_type to_device(platform::complex<half_t> x) { return reinterpret_cast<device_type const &>(x); }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -247,6 +263,7 @@ struct TypeTraits<platform::complex<float> > {
|
||||
}
|
||||
|
||||
static inline platform::complex<float> to_print(platform::complex<float> x) { return x; }
|
||||
static inline device_type to_device(platform::complex<float> x) { return reinterpret_cast<device_type const &>(x); }
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -263,6 +280,7 @@ struct TypeTraits<platform::complex<double> > {
|
||||
);
|
||||
}
|
||||
static inline platform::complex<double> to_print(platform::complex<double> x) { return x; }
|
||||
static inline device_type to_device(platform::complex<double> x) { return reinterpret_cast<device_type const &>(x); }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user