CUTLASS 2.3 initial commit (#134)

CUTLASS 2.3 adds GEMMs targeting Sparse Tensor Cores on the NVIDIA Ampere Architecture, fast SGEMM kernels, small matrix classes, bug fixes, and performance enhancements.
This commit is contained in:
Andrew Kerr
2020-09-23 14:00:58 -07:00
committed by GitHub
parent 4dac7490e6
commit c53f3339bb
209 changed files with 46922 additions and 1677 deletions

View File

@ -2,6 +2,20 @@
# CUTLASS 2.x
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
* [Sparse Tensor Core GEMM kernels](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
* Minor Features:
* [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
* Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code
* [Floating-point constants](/include/cutlass/constants.h)
* NVIDIA Ampere GPU Architecture examples and documentation:
* [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu)
* [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
* Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue)
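
For orientation, the sketch below shows the device-level GEMM interface that the fast SGEMM kernels plug into. With all template defaults, CUTLASS selects a SIMT single-precision kernel; the wrapper function name and the pointer/leading-dimension parameters are illustrative placeholders, not part of this release.

```
#include "cutlass/gemm/device/gemm.h"

// Minimal single-precision GEMM: D = alpha * A * B + beta * C, written over C.
// A, B, and C are assumed to be column-major device allocations.
cutlass::Status run_sgemm(int M, int N, int K,
                          float alpha, float const *A, int lda,
                          float const *B, int ldb,
                          float beta, float *C, int ldc) {
  using ColumnMajor = cutlass::layout::ColumnMajor;
  using Gemm = cutlass::gemm::device::Gemm<float, ColumnMajor,   // A
                                           float, ColumnMajor,   // B
                                           float, ColumnMajor>;  // C and D
  Gemm gemm_op;
  return gemm_op({{M, N, K}, {A, lda}, {B, ldb}, {C, ldc}, {C, ldc}, {alpha, beta}});
}
```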
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
* Fast Tensor Core operations:

View File

@ -32,7 +32,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.3.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
find_package(Doxygen QUIET)
@ -69,6 +69,8 @@ endif()
set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Library")
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Profiler")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
@ -101,6 +103,9 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 11.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.1)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
@ -164,12 +169,14 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.
#
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
@ -181,6 +188,11 @@ else()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
endif()
# Trace levels for debugging
set(CUTLASS_DEBUG_TRACE_LEVEL "0" CACHE STRING "Level of debug tracing to perform.")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_TRACE_LEVEL})
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.")
@ -352,7 +364,7 @@ set_target_properties(CUTLASS PROPERTIES EXPORT_NAME cutlass)
set(CUTLASS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library/)
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library CACHE INTERNAL "Location of generator scripts")
# The following utility directory is needed even if the tools build is disabled, so it exists here.
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")

View File

@ -16,6 +16,9 @@ Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Aniket Shivam
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
@ -52,6 +55,8 @@ Olivier Giroux
Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Matthew Nicely
Joel McCormack
Kyrylo Perelygin

View File

@ -213,7 +213,14 @@ function(cutlass_correct_source_file_language_property)
endif()
endfunction()
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
# If building with all kernels, set UNITY build on by default.
if (CUTLASS_LIBRARY_KERNELS MATCHES "all")
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
else()
set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF)
endif()
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
function(cutlass_unify_source_files TARGET_ARGS_VAR)

View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 2.2
# CUTLASS 2.3
_CUTLASS 2.2 - June 2020_
_CUTLASS 2.3 - September 2020_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@ -30,6 +30,14 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# What's New in CUTLASS 2.3
CUTLASS 2.3 is a minor update to CUTLASS adding:
- GEMMs targeting structured [Sparse Tensor Cores](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) in NVIDIA Ampere Architecture GPUs
- Fast SGEMM kernels targeting GeForce RTX 30-series CUDA Cores
- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit)
- See the [CHANGELOG](CHANGELOG.md) for more details.
# What's New in CUTLASS 2.2
CUTLASS 2.2 is a significant update to CUTLASS adding:
@ -42,7 +50,7 @@ CUTLASS 2.2 is a significant update to CUTLASS adding:
# What's New in CUTLASS 2.1
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
CUTLASS 2.1 is a minor update to CUTLASS adding:
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
@ -71,8 +79,8 @@ using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
# Compatibility
CUTLASS requires a C++11 host compiler and
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
performs best when built with the [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, CUDA 10.2, and CUDA 11.0.
We have tested the following environments.
@ -99,10 +107,11 @@ any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GP
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA GeForce RTX 3090|8.6|11.1|11.1|
# Documentation
CUTLASS 2.2 is described in the following documents and the accompanying
CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
@ -136,14 +145,14 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.
```
$ mkdir build && cd build
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture
$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA's Ampere Architecture
```
From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make.
@ -258,15 +267,25 @@ The `tools/profiler/` directory contains a command-line utility for launching ea
It can be built as follows:
```
$ make cutlass_profiler -j
$ make cutlass_profiler -j16
```
To limit compilation time, only one tile size is instantiated for each data type, math instruction, and layout.
By default, only one tile size is instantiated for each data type, math instruction, and layout.
To instantiate all of them, set the following CMake variable when running CMake from an empty `build/` directory.
Beware, this results in *thousands* of kernels and long build times.
```
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j
$ make cutlass_profiler -j16
```
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
wildcard characters may be used to reduce the set of kernels. The following builds exactly one kernel:
```
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1
...
$ make cutlass_profiler -j16
```
Example command line for profiling SGEMM kernels is as follows:

View File

@ -69,7 +69,7 @@
template <typename Element, typename GmemIterator, typename SmemIterator>
__global__ void kernel_dump(typename GmemIterator::Params params,
typename GmemIterator::TensorRef ref) {
__shared__ Element shared_storage[EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL];
extern __shared__ Element shared_storage[];
// Construct the global iterator and load the data to the fragments.
int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x;
@ -164,8 +164,11 @@ int main() {
dim3 grid(1, 1);
dim3 block(32, 1, 1);
int smem_size =
int(sizeof(Element) * EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL);
kernel_dump<Element, GmemIterator, SmemIterator>
<<<grid, block>>>(params, matrix.device_ref());
<<<grid, block, smem_size, 0>>>(params, matrix.device_ref());
cudaError_t result = cudaDeviceSynchronize();
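
Since the kernel now sizes its shared memory dynamically at launch, configurations larger than the default per-block limit would also need an explicit opt-in. A hedged sketch of that guard, reusing the names from the example above (the 48 KB threshold is the common default, not something this small example hits):

```
// Hypothetical guard, not part of the example: kernels requesting more dynamic
// shared memory than the default limit must opt in before launch.
if (smem_size > 48 * 1024) {
  cudaFuncSetAttribute(kernel_dump<Element, GmemIterator, SmemIterator>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       smem_size);
}
```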

View File

@ -50,7 +50,7 @@
To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
This builds all planar complex GEMM variants for Volta and Turing architectures.
@ -59,7 +59,7 @@
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn
$ make 10_planar_complex
@ -526,6 +526,11 @@ int main(int argc, char const **args) {
return 0;
}
}
else {
// NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond.
//
// fall through
}
//
// Parse options

View File

@ -48,7 +48,7 @@
To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
This builds all planar complex GEMM variants for Volta and Turing architectures.
@ -57,7 +57,7 @@
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn
$ make 11_planar_complex_array
@ -586,6 +586,11 @@ int main(int argc, char const **args) {
return 0;
}
}
else {
// NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond.
//
// fall through
}
//
// Parse options

View File

@ -0,0 +1,205 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
void run_nonfused_gemm_s8_sm80() {
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(0);
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using Gemm0 = cutlass::gemm::device::Gemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
WarpShape0,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate,
true
>;
using Gemm1 = cutlass::gemm::device::Gemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape1,
WarpShape1,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate,
true
>;
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
}
void run_fused_gemm_s8_sm80() {
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(0);
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>;
using B2bGemm = cutlass::gemm::device::B2bGemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate,
true
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

View File

@ -38,6 +38,8 @@
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
@ -115,7 +117,9 @@ struct B2bInterleavedNonFusedGemmRun
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true) {
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
@ -232,6 +236,13 @@ struct B2bInterleavedNonFusedGemmRun
status = gemm_op_1.initialize(arguments_1);
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
status = gemm_op_1();
CUTLASS_CHECK(status);
}
//
// Run the GEMM
//
@ -242,14 +253,14 @@ struct B2bInterleavedNonFusedGemmRun
cudaEventRecord(start);
for(int i = 0; i < 100; i++) {
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < 100; i++) {
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
@ -261,9 +272,9 @@ struct B2bInterleavedNonFusedGemmRun
cudaEventElapsedTime(&gemm0Time, start, stop1);
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
std::cout << "total time " << totalTime / 100.0 << " ms\n";
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
std::cout << "total time " << totalTime / (float)runs << " ms\n";
tensor_D0.sync_host();
tensor_D1.sync_host();
@ -302,7 +313,7 @@ struct B2bInterleavedNonFusedGemmRun
reference_gemm_1(
problem_size_1,
alpha1,
tensor_D0.device_ref(),
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
@ -420,7 +431,9 @@ struct B2bInterleavedFusedGemmRun
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true) {
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
@ -478,7 +491,7 @@ struct B2bInterleavedFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
//Reorder B0
cutlass::reorder_column<B2bGemm::InstructionShape::kK>(
cutlass::reorder_column<16>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
@ -526,6 +539,11 @@ struct B2bInterleavedFusedGemmRun
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
//
// Run the GEMM
//
@ -536,7 +554,7 @@ struct B2bInterleavedFusedGemmRun
cudaEventRecord(start);
for(int i = 0; i < 100; i++) {
for(int i = 0; i < runs; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
@ -546,7 +564,7 @@ struct B2bInterleavedFusedGemmRun
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "time " << gemmTime / 100.0 << " ms\n";
std::cout << "time " << gemmTime / (float)runs << " ms\n";
//tensor_D0.sync_host();
tensor_D1.sync_host();
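
With the new `warm_ups` and `runs` parameters, callers control how many untimed iterations precede the measurement and how many launches the reported averages cover. A hedged usage sketch with illustrative values, relying on the types and problem sizes defined in the SM80 test header:

```
// Illustrative invocation only; Gemm0/Gemm1 and the problem sizes are the
// definitions from b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h.
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
bool passed = nonFusedGemm.run(problem_size_0, problem_size_1,
                               alpha0, beta0, alpha1, beta1,
                               /*relu=*/true,
                               /*warm_ups=*/5,   // untimed warm-up launches
                               /*runs=*/200);    // timed launches averaged in the report
```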

View File

@ -30,7 +30,6 @@ two unfused GEMM operations, demonstrating a speedup of the fused kernel on the
NVIDIA Turing GPU architecture.
Problem size:
GEMM1 (M,N,K): 128*1600, 64, 576
GEMM2 (M,N,K): 128*1600, 128, 64
@ -42,16 +41,17 @@ also requires warp_tile_N = thread_block_tile_N so the data required by each war
register-file-resident.
Performance:
- fp16 on Tesla T4 @ 1590MHz (non-fused vs. fused): 1.39011 ms vs. 1.26035 ms
- int8 on Tesla T4 @ 1590MHz (non-fused vs. fused): 0.751759 ms vs. 0.62971 ms
- fp16 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.721144 ms vs. 0.629864 ms
- int8 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.379049 ms vs. 0.324764 ms
- int8 on GA100 @ 1200MHz (non-fused vs. fused): 0.153795 ms vs. 0.129874 ms
*/
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h"
int run() {
@ -71,7 +71,10 @@ int run() {
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
run_nonfused_gemm_s8_sm80();
run_fused_gemm_s8_sm80();
#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
run_nonfused_gemm_f16();
run_fused_gemm_f16();
run_nonfused_gemm_s8();
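
The header comment above notes that fusion requires warp_tile_N to equal thread_block_tile_N so that each warp's slice of the first GEMM's accumulator stays register-resident. A hedged compile-time check of that constraint, written against the SM80 tile shapes defined in run_fused_gemm_s8_sm80() (it is not part of the example itself):

```
// Hypothetical checks placed alongside the fused SM80 tile definitions:
// each warp must cover the full N extent of its threadblock tile.
static_assert(WarpShape0::kN == ThreadblockShape0::kN,
              "GEMM0 warp tile must span the threadblock tile in N");
static_assert(WarpShape1::kN == ThreadblockShape1::kN,
              "GEMM1 warp tile must span the threadblock tile in N");
```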

View File

@ -210,7 +210,8 @@ struct B2bGemm {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -313,7 +314,8 @@ struct B2bGemm {
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -217,7 +217,85 @@ struct DefaultB2bGemm<
};
/// Partial specialization for Turing IMMA Interleaved layout
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
template <
/// Element type for A matrix operand
typename ElementA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Is Beta zero or not
bool IsBetaZero>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator, IsBetaZero> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
using ElementAccumulator = int32_t;
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0,
true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::
DefaultInterleavedEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
64 / sizeof_bits<ElementC>::value, InterleavedK,
IsBetaZero>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
template <
/// Element type for A matrix operand
typename ElementA,

View File

@ -0,0 +1,862 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "threadblock/b2b_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape0_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA0_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA0_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA0,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB0_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB0_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB0,
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape1_,
/// Iterates over the intermediate accumulator tile
// (concept::MmaTensorOpFragmentIterator)
typename FragmentIteratorA1_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB1_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB1_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB1,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
typename OutputOp_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy0_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy1_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class B2bMmaMultistage :
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {
public:
///< Base class
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
///< Policy describing tuning details
using Policy0 = Policy0_;
using SmemIteratorA0 = SmemIteratorA0_;
using SmemIteratorB0 = SmemIteratorB0_;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape1 = Shape1_;
///< Iterates over intermediate accumulator tile
using FragmentIteratorA1 = FragmentIteratorA1_;
///< Iterates over tiles of B operand in global memory
using IteratorB1 = IteratorB1_;
///< Policy describing tuning details
using Policy1 = Policy1_;
using SmemIteratorB1 = SmemIteratorB1_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Epilogue after 1st Gemm
using OutputOp = OutputOp_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC0 = typename Policy0::Operator::FragmentC;
/// Warp-level Mma
using Operator0 = typename Policy0::Operator;
/// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC;
/// Warp-level Mma
using Operator1 = typename Policy1::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations0 > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
static_assert(Base::kWarpGemmIterations1 > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const TBLDGSTSIterationsA0 =
IteratorA0::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const TBLDGSTSIterationsB0 =
IteratorB0::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const TBLDGSTSIterationsB1 =
IteratorB1::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA0 =
(TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB0 =
(TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB1 =
(TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
};
private:
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
/// Warp Fragment of operand A1 loaded from accumulator tile
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA0 smem_iterator_A0_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB0 smem_iterator_B0_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB1 smem_iterator_B1_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
B2bMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::B2bMmaSharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A0_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k});
this->warp_tile_iterator_B0_.add_tile_offset(
{Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n});
this->warp_tile_iterator_B1_.add_tile_offset(
{Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
int group_start_A0 = 0, int group_start_B0 = 0) {
iterator_A0.set_iteration_index(group_start_A0 *
IteratorA0::kAccessesPerVector);
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
// LDGSTS for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) {
typename IteratorA0::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA0::AccessType *>(
this->smem_iterator_A0_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
IteratorA0::ThreadMap::kElementsPerAccess /
IteratorA0::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A0.get();
cutlass::arch::cp_async<kSrcBytes, kCacheOpA0>(
dst_ptr + v, gmem_ptr, iterator_A0.valid());
++iterator_A0;
}
++this->smem_iterator_A0_;
}
}
iterator_B0.set_iteration_index(group_start_B0 *
IteratorB0::kAccessesPerVector);
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) {
typename IteratorB0::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB0::AccessType *>(
this->smem_iterator_B0_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
IteratorB0::ThreadMap::kElementsPerAccess /
IteratorB0::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B0.get();
cutlass::arch::cp_async<kSrcBytes, kCacheOpB0>(
dst_ptr + v, gmem_ptr, iterator_B0.valid());
++iterator_B0;
}
++this->smem_iterator_B0_;
}
}
}
CUTLASS_DEVICE
void copy_tiles_and_advance_1(IteratorB1 &iterator_B1,
int group_start_B1 = 0) {
iterator_B1.set_iteration_index(group_start_B1 *
IteratorB1::kAccessesPerVector);
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
typename IteratorB1::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB1::AccessType *>(
this->smem_iterator_B1_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
IteratorB1::ThreadMap::kElementsPerAccess /
IteratorB1::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B1.get();
cutlass::arch::cp_async<kSrcBytes, kCacheOpB1>(
dst_ptr + v, gmem_ptr, iterator_B1.valid());
++iterator_B1;
}
++this->smem_iterator_B1_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations_0,
///< destination accumulator tile
FragmentC1 &accum,
///< iterator over A operand in global memory
IteratorA0 iterator_A0,
///< iterator over B operand in global memory
IteratorB0 iterator_B0,
///< iterator over B operand in global memory
IteratorB1 iterator_B1,
///< initial value of accumulator
FragmentC0 const &src_accum,
///< epilogue operation after 1st Gemm
OutputOp output_op_0)
{
//
// Prologue
//
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations_0) {
if (gemm_k_iterations_0 == 0) {
iterator_A0.clear_mask();
iterator_B0.clear_mask();
}
iterator_A0.set_iteration_index(0);
this->smem_iterator_A0_.set_iteration_index(0);
// LDGSTS for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) {
typename IteratorA0::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA0::AccessType *>(
this->smem_iterator_A0_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA0::Element>::value *
IteratorA0::ThreadMap::kElementsPerAccess /
IteratorA0::kAccessesPerVector / 8;
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
++iterator_A0;
}
++this->smem_iterator_A0_;
}
iterator_B0.set_iteration_index(0);
this->smem_iterator_B0_.set_iteration_index(0);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) {
typename IteratorB0::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB0::AccessType *>(
this->smem_iterator_B0_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB0::Element>::value *
IteratorB0::ThreadMap::kElementsPerAccess /
IteratorB0::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
++iterator_B0;
}
++this->smem_iterator_B0_;
}
// Move to the next stage
iterator_A0.add_tile_offset({0, 1});
iterator_B0.add_tile_offset({1, 0});
this->smem_iterator_A0_.add_tile_offset({0, 1});
this->smem_iterator_B0_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
FragmentC0 accum0 = src_accum;
// DEPBAR+SYNC
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
Operator0 warp_mma0;
this->warp_tile_iterator_A0_.set_kgroup_index(0);
this->warp_tile_iterator_B0_.set_kgroup_index(0);
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
++this->warp_tile_iterator_A0_;
++this->warp_tile_iterator_B0_;
if (gemm_k_iterations_0 == 0) {
iterator_A0.clear_mask();
iterator_B0.clear_mask();
}
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A0_;
++this->warp_tile_iterator_B0_;
if (warp_mma_k > 0)
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
warp_transformed_frag_B0[warp_mma_k % 2],
warp_loaded_frag_A0[warp_mma_k % 2],
warp_loaded_frag_B0[warp_mma_k % 2]);
warp_mma0(
accum0,
warp_transformed_frag_A0[warp_mma_k % 2],
warp_transformed_frag_B0[warp_mma_k % 2],
accum0
);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations0 - 1) {
int group_start_iteration_A0, group_start_iteration_B0;
group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0;
group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0;
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
group_start_iteration_B0);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
int group_start_iteration_A0, group_start_iteration_B0;
group_start_iteration_A0 =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
group_start_iteration_B0 =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
group_start_iteration_B0);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A0.add_tile_offset({0, 1});
iterator_B0.add_tile_offset({1, 0});
this->smem_iterator_A0_.add_tile_offset({0, 1});
this->smem_iterator_B0_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A0_.add_tile_offset(
{0, -Base::kStages * Policy0::kPartitionsK *
Base::kWarpGemmIterations0});
this->warp_tile_iterator_B0_.add_tile_offset(
{-Base::kStages * Policy0::kPartitionsK *
Base::kWarpGemmIterations0,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations_0;
if (gemm_k_iterations_0 == 0) {
iterator_A0.clear_mask();
iterator_B0.clear_mask();
}
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
}
}
// 2nd Gemm
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
//
// Prologue
//
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations_1) {
if (gemm_k_iterations_1 == 0) {
// iterator_A1.clear_mask();
iterator_B1.clear_mask();
}
#if 0
iterator_A1.set_iteration_index(0);
this->smem_iterator_A1_.set_iteration_index(0);
// LDGSTS for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsA1; ++j) {
typename IteratorA1::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA1::AccessType *>(
this->smem_iterator_A1_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA1::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA1::Element>::value *
IteratorA1::ThreadMap::kElementsPerAccess /
IteratorA1::kAccessesPerVector / 8;
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
++iterator_A0;
}
++this->smem_iterator_A0_;
}
#endif
iterator_B1.set_iteration_index(0);
this->smem_iterator_B1_.set_iteration_index(0);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
typename IteratorB1::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB1::AccessType *>(
this->smem_iterator_B1_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB1::Element>::value *
IteratorB1::ThreadMap::kElementsPerAccess /
IteratorB1::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
++iterator_B1;
}
++this->smem_iterator_B1_;
}
// Move to the next stage
//iterator_A1.add_tile_offset({0, 1});
iterator_B1.add_tile_offset({1, 0});
//this->smem_iterator_A1_.add_tile_offset({0, 1});
this->smem_iterator_B1_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
// FragmentC0 accum0 = src_accum;
// DEPBAR+SYNC
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
Operator1 warp_mma1;
// this->warp_tile_iterator_A1_.set_kgroup_index(0);
this->warp_tile_iterator_B1_.set_kgroup_index(0);
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;
if (gemm_k_iterations_1 == 0) {
// iterator_A1.clear_mask();
iterator_B1.clear_mask();
}
smem_write_stage_idx = Base::kStages - 1;
smem_read_stage_idx = 0;
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
//
// Mainloop
//
CUTLASS_PRAGMA_UNROLL
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
// this->warp_tile_iterator_A1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;
if (warp_mma_k > 0)
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
warp_loaded_frag_A1[warp_mma_k % 2],
warp_loaded_frag_B1[warp_mma_k % 2]);
warp_mma1(
accum,
warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
accum
);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
int group_start_iteration_B1;
group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
int group_start_iteration_B1;
group_start_iteration_B1 =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_B1.add_tile_offset({1, 0});
this->smem_iterator_B1_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_B1_.add_tile_offset(
{-Base::kStages * Policy0::kPartitionsK *
Base::kWarpGemmIterations1,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
// --gemm_k_iterations_1;
if (gemm_k_iterations_1 == 1) {
iterator_B1.clear_mask();
}
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -48,10 +48,6 @@ namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////////////////////
template<int a>
struct chk_val {
static_assert(a==0, "check value");
};
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <

View File

@ -40,6 +40,7 @@
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "threadblock/b2b_mma_pipelined.h"
#include "threadblock/b2b_mma_multistage.h"
////////////////////////////////////////////////////////////////////////////////
@ -200,8 +201,6 @@ template <
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
@ -220,7 +219,7 @@ template <
int InterleavedK>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
kAlignmentB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp, true> {
// Define the MmaCore components
@ -251,7 +250,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
// Use fragment iterator for A operand
// Use fragment iterator for A1 operand
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
using FragmentIteratorA1 =
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
@ -282,6 +281,111 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
////////////////////////////////////////////////////////////////////////////////
/// Specialization for column-major-interleaved output
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Number of stages used in the multistage mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Number of Interleaved K
int InterleavedK>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
kAlignmentB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp, true> {
// Define the MmaCore components
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
Operator, true>;
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
Operator, true>;
// Define iterators over tiles from the A operand
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA0 =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB0 =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>;
// Use fragment iterator for A1 operand
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
using FragmentIteratorA1 =
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout,
InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>;
// Define iterators over tiles from the B operand
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage<
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
MmaCore0::kCacheOpA,
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
EpilogueOutputOp,
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,27 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
14_ampere_tf32_tensorop_gemm
ampere_tf32_tensorop_gemm.cu
)

View File

@ -0,0 +1,278 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
Please check examples 07 and 08 for the basics of tensor op gemm kernels. On the NVIDIA Ampere
architecture, most concepts still hold. The two main differences are
1. The NVIDIA Ampere architecture introduces a new series of tensor core instructions (see
include/cutlass/arch/mma_sm80.h) which are more efficient on Ampere.
2. The NVIDIA Ampere architecture uses cp_async() to build a multistage software pipeline to
better hide latency (see include/cutlass/gemm/threadblock/mma_multistage.h).
Moreover, the NVIDIA Ampere architecture supports the tfloat32 (see include/cutlass/tfloat32.h)
data type in tensor cores. One big advantage is that we can load fp32 data and convert it
implicitly to tf32 inside the GEMM kernel, which means no change is needed to accelerate
traditional fp32 workloads on the NVIDIA Ampere architecture.
*/
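// --- Sketch (not part of this commit): a rough host-side model of the precision implied by the
// implicit fp32 -> tf32 conversion described above. TF32 keeps the fp32 exponent (8 bits) but only
// 10 mantissa bits; truncating the low 13 mantissa bits approximates the effect (the hardware
// actually rounds to nearest). The helper name is illustrative only.
#include <cstdint>
#include <cstring>
inline float tf32_truncate_approx(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));        // reinterpret the fp32 bit pattern
  bits &= ~((std::uint32_t(1) << 13) - 1);     // zero the 13 low mantissa bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}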
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output matrices and of the
// computation between elements of the input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = float; // <- data type of elements in input matrix A
using ElementInputB = float; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes the matrix layout of input and output matrices. Row Major for
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- default identity threadblock swizzle
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access. For a byte, it's 16
// elements. This becomes the vector width of
// math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
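// Not part of this commit: for float output the vector width above resolves to 128 / 32 = 4
// elements, i.e. each epilogue access is a 128-bit (float4-sized) load or store.
static_assert(128 / cutlass::sizeof_bits<ElementOutput>::value == 4,
              "float epilogue uses 4-element (128-bit) vectorized accesses");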
// Number of pipeline stages to use
constexpr int NumStages = 4;
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;
int run() {
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
return -1;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!((props.major * 10 + props.minor) >= 80)) {
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
// Return 0 so tests are considered passing if run on unsupported platforms.
return 0;
}
const int length_m = 5120;
const int length_n = 4096;
const int length_k = 4096;
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// CUTLASS kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// reference kernel
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(4),
ElementInputB(-4),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(4),
ElementOutput(-4),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
// Initialize alpha and beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
// Split the K dimension into 1 partition
int split_k_slices = 1;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
tensor_a.device_ref(), // <- reference to matrix A on device
tensor_b.device_ref(), // <- reference to matrix B on device
tensor_c.device_ref(), // <- reference to matrix C on device
tensor_d.device_ref(), // <- reference to matrix D on device
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Initialize CUTLASS kernel with arguments and workspace pointer
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
// Create instantiation for device reference gemm kernel
cutlass::reference::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue>
gemm_device;
// Launch device reference gemm kernel
gemm_device(problem_size,
alpha,
tensor_a.device_ref(),
tensor_b.device_ref(),
beta,
tensor_c.device_ref(),
tensor_ref_d.device_ref());
// Wait for kernels to finish
cudaDeviceSynchronize();
// Copy output data from CUTLASS and reference kernel to host for comparison
tensor_d.sync_host();
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
// Returning zero so this test passes when built on older Toolkits.
return 0;
}
else {
return run();
}
}

View File

@ -0,0 +1,27 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
15_ampere_sparse_tensorop_gemm
ampere_sparse_tensorop_gemm.cu
)

View File

@ -0,0 +1,311 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
Please check examples 07, 08 and 17 for the basics of dense tensor op gemm kernels. The NVIDIA
Ampere architecture also supports structured sparse tensor ops for tf32, fp16, int8 and int4.
Sparse GEMM kernels need to take an additional E matrix which stores the metadata. The format of
the metadata is different for each data type. CUTLASS templates can automatically infer it based
on the input A and B types. Check the code below.
Moreover, matrix E needs to be preprocessed so that ldmatrix can be used to load it into registers
efficiently.
*/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/host_uncompress.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output matrices and of the
// computation between elements of the input matrices.
using ElementAccumulator = int32_t; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B
using ElementOutput = int32_t; // <- data type of elements in output matrix D
// The code section below describes the matrix layout of input and output matrices. Row Major for
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<256, 128, 256>; // <- threadblock tile M = 256, N = 128, K = 256
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- default identity threadblock swizzle
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access. For a byte, it's 16
// elements. This becomes the vector width of
// math instructions in the epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
// Number of pipeline stages to use
constexpr int NumStages = 3;
using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;
// Data type and layout of meta data matrix E can be inferred from template Gemm.
using ElementInputE = typename Gemm::ElementE;
using LayoutInputE = typename Gemm::LayoutE;
// The properties below are defined in include/cutlass/arch/sp_mma_sm80.h
// 50% Sparsity on Ampere
constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of an individual metadata element in bits
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
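// Not part of this commit: a small sketch of how the tensor extents used below follow from these
// constants. With 2:4 structured sparsity kSparse == 2, so the compressed A keeps K / 2 columns,
// and each metadata element of E covers kElementsPerElementE of them. The helper names are
// illustrative only.
inline cutlass::Coord<2> compressed_a_extent(cutlass::gemm::GemmCoord const &p) {
  return cutlass::make_Coord(p.m(), p.k() / kSparse);
}
inline cutlass::Coord<2> metadata_extent(cutlass::gemm::GemmCoord const &p) {
  return cutlass::make_Coord(p.m(), p.k() / kSparse / kElementsPerElementE);
}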
int run() {
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.1.
//
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
return -1;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!((props.major * 10 + props.minor) >= 80)) {
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
// Return 0 so tests are considered passing if run on unsupported platforms.
return 0;
}
const int length_m = 512;
const int length_n = 512;
const int length_k = 1024;
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// CUTLASS kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// reference kernel
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Same size as the above. The above one needs to be reordered and stored in this one.
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(1),
ElementInputA(-1),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(1),
ElementInputB(-1),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(1),
ElementOutput(-1),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomSparseMeta(
tensor_e.host_view(),
1,
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
// instructions.
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
{problem_size.m(), problem_size.n(),
problem_size.k() / kSparse / kElementsPerElementE});
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_e_reordered.sync_device();
tensor_ref_d.sync_device();
// Initialize alpha and beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
// Split the K dimension into 1 partition
int split_k_slices = 1;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
tensor_a.device_ref(), // <- reference to matrix A on device
tensor_b.device_ref(), // <- reference to matrix B on device
tensor_c.device_ref(), // <- reference to matrix C on device
tensor_d.device_ref(), // <- reference to matrix D on device
tensor_e.device_ref(), // <- reference to matrix E on device
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Initialize CUTLASS kernel with arguments and workspace pointer
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
// Uncompress tensor_a based on the metadata tensor_e. We need it for the reference computation.
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
tensor_e.host_ref(), problem_size.m(), problem_size.k());
// Create instantiation for host reference gemm kernel
cutlass::reference::host::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue,
typename Gemm::Operator>
gemm_host;
// Launch host reference gemm kernel
gemm_host(problem_size,
alpha,
tensor_a_uncompressed.host_ref(),
tensor_b.host_ref(),
beta,
tensor_c.host_ref(),
tensor_ref_d.host_ref());
// Copy output data from the CUTLASS kernel to the host for comparison
tensor_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.1.
//
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
// Returning zero so this test passes when built on older Toolkits.
return 0;
}
else {
return run();
}
}

View File

@ -71,6 +71,8 @@ foreach(EXAMPLE
11_planar_complex_array
12_gemm_bias_relu
13_fused_two_gemms
14_ampere_tf32_tensorop_gemm
15_ampere_sparse_tensorop_gemm
)
add_subdirectory(${EXAMPLE})

View File

@ -55,6 +55,17 @@ struct Sm75 {
struct Sm80 {
static int const kMinComputeCapability = 80;
};
struct Sm86 {
static int const kMinComputeCapability = 86;
};
/// Triggers a breakpoint on the device
CUTLASS_DEVICE
void device_breakpoint() {
#if defined(__CUDA_ARCH__)
asm volatile (" brkpt;\n");
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -51,6 +51,8 @@ struct global_load;
/////////////////////////////////////////////////////////////////////////////////////////////////
// The redundant mov PTX instruction is used to force the compiler to
// initialize data to zero before ld.global
template <typename AccessType
>
struct global_load<AccessType,
@ -83,7 +85,6 @@ struct global_load<AccessType,
}
};
template <typename AccessType
>
struct global_load<AccessType,

View File

@ -150,6 +150,42 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, El
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Specifies the sparse metadata format
struct SPFormatType {
enum Kind {
Thread
};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Matrix multiply-add operation
template <
/// Size of the matrix product (concept: GemmShape)
typename Shape_,
/// Number of threads participating
int kThreads_,
/// Data type of A elements
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Data type of B elements
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Inner product operator
typename Operator,
/// Specifies meta data format
SPFormatType::Kind SPFormat = SPFormatType::Thread
>
struct SparseMma;
} // namespace arch
} // namespace cutlass
@ -165,4 +201,5 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, El
#include "cutlass/arch/mma_sm70.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/sp_mma_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -53,6 +53,7 @@ template <
struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
@ -79,6 +80,7 @@ template <
struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
@ -106,6 +108,7 @@ template <
struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
@ -142,6 +145,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
CUTLASS_HOST_DEVICE
void operator()(
@ -181,6 +185,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
CUTLASS_HOST_DEVICE
void operator()(
@ -218,6 +223,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
CUTLASS_HOST_DEVICE
void operator()(
@ -255,6 +261,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
CUTLASS_HOST_DEVICE
void operator()(
@ -292,6 +299,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
CUTLASS_HOST_DEVICE
void operator()(
@ -327,6 +335,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAddComplex;
CUTLASS_HOST_DEVICE
void operator()(
@ -355,7 +364,8 @@ template <
struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
Array<float, 1> &d,

View File

@ -55,6 +55,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<2, 1, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
@ -99,6 +100,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 2, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
@ -143,6 +145,7 @@ struct Mma <
OpMultiplyAdd> {
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
@ -196,7 +199,8 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
Array<half_t, 4> &d,

View File

@ -51,7 +51,8 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 4>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(
Array<int, 1> &d,
@ -98,6 +99,7 @@ struct Mma<
OpMultiplyAdd> {
using Shape = gemm::GemmShape<1, 1, 2>;
using Operator = OpMultiplyAdd;
CUTLASS_HOST_DEVICE
void operator()(

View File

@ -723,7 +723,6 @@ struct Mma<
}
};
////////////////////////////////////////////////////////////////////////////////
//
// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE

View File

@ -85,7 +85,7 @@ Array<T, N> mac(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c
Array<T, N> d;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
d[i] = a[i] * b[i] + c;
d[i] = a[i] * b[i] + c[i];
}
return d;
}

File diff suppressed because it is too large.

View File

@ -487,6 +487,46 @@ public:
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 1> make_Array(Element x) {
Array<Element, 1> m;
m[0] = x;
return m;
}
template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 2> make_Array(Element x, Element y) {
Array<Element, 2> m;
m[0] = x;
m[1] = y;
return m;
}
template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 3> make_Array(Element x, Element y, Element z) {
Array<Element, 3> m;
m[0] = x;
m[1] = y;
m[2] = z;
return m;
}
template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 4> make_Array(Element x, Element y, Element z, Element w) {
Array<Element, 4> m;
m[0] = x;
m[1] = y;
m[2] = z;
m[3] = w;
return m;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////
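// Usage sketch (not part of this commit) for the make_Array helpers added above; the function
// name is illustrative only.
CUTLASS_HOST_DEVICE void make_array_example() {
  cutlass::Array<float, 3> v = cutlass::make_Array(1.0f, 2.0f, 3.0f);
  (void)v;  // v[0] == 1.0f, v[1] == 2.0f, v[2] == 3.0f
}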

View File

@ -65,7 +65,7 @@ struct alignas(2) bfloat16_t {
/// Default constructor
CUTLASS_HOST_DEVICE
bfloat16_t() { }
bfloat16_t() : storage(0) { }
/// Floating-point conversion - round toward nearest
CUTLASS_HOST_DEVICE

View File

@ -187,10 +187,12 @@ class complex
/// Division
template <typename A>
CUTLASS_HOST_DEVICE complex<T> operator/(complex<A> const &rhs) const {
T d = (rhs.real() * (rhs) + rhs.imag() * rhs.imag());
T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag());
return complex<T>((this->real() * (rhs) + this->imag() * rhs.imag()) / d,
(this->imag() * (rhs)-this->real() * rhs.imag()) / d);
return complex<T>(
(real() * rhs.real() + imag() * rhs.imag()) / d,
(imag() * rhs.real() - real() * rhs.imag()) / d
);
}
/// Scalar Division

include/cutlass/constants.h (new file, 1233 lines)
File diff suppressed because it is too large.

View File

@ -439,6 +439,12 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
return Coord<4>(values);
}
/// Helper to make a 5-element coordinate
CUTLASS_HOST_DEVICE
Coord<5> make_Coord(int _0, int _1, int _2, int _3, int _4) {
int values[5] = {_0, _1, _2, _3, _4};
return Coord<5>(values);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -31,18 +31,43 @@
#include <iostream>
#include <typeinfo>
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Output operator for CUDA built-in dim3 type
inline std::ostream &operator<<(std::ostream &out, dim3 d) {
return out << d.x << ", " << d.y << ", " << d.z;
}
/// Output operator for CUDA built-in error type
inline std::ostream &operator<<(std::ostream &out, cudaError_t error) {
return out << cudaGetErrorString(error);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
// stream operators for cutlass namespace //
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element, int Rank>
inline
std::ostream& operator<<(std::ostream& out, Array<Element, Rank> const& v) {
for (int i = 0; i < Rank; ++i) {
out << (i ? ", " : "") << v[i];
}
return out;
}
template <int Rank>
inline
std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
@ -115,7 +140,7 @@ inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scal
/// Default printing to ostream for MatrixShape
template <int Row, int Column>
inline
std::ostream & operator<<(std::ostream &out, cutlass::MatrixShape<Row, Column> const &matrix_shape) {
std::ostream & operator<<(std::ostream &out, MatrixShape<Row, Column> const &matrix_shape) {
out << "cutlass::MatrixShape::(kRow, kColumn) {"
<< cutlass::MatrixShape<Row,Column>::kRow <<","
<< cutlass::MatrixShape<Row,Column>::kColumn <<"}";
@ -130,7 +155,7 @@ namespace gemm {
/// Default printing to ostream for GemmShape
template <int M, int N, int K>
inline
std::ostream & operator<<(std::ostream &out, cutlass::gemm::GemmShape<M,N,K> const &gemm_shape) {
std::ostream & operator<<(std::ostream &out, GemmShape<M,N,K> const &gemm_shape) {
out << "cutlass::GemmShape::(kM, kN, kK) {"
<< cutlass::gemm::GemmShape<M,N,K>::kM <<","
<< cutlass::gemm::GemmShape<M,N,K>::kN <<","
@ -150,7 +175,7 @@ namespace layout {
/// Default printing to ostream for PitchLinearShape
template < int Contiguous, int Strided>
inline
std::ostream & operator<<(std::ostream &out, cutlass::layout::PitchLinearShape<Contiguous, Strided> const &pitch_linear_shape) {
std::ostream & operator<<(std::ostream &out, PitchLinearShape<Contiguous, Strided> const &pitch_linear_shape) {
out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {"
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kContiguous <<","
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kStrided <<"}";

View File

@ -125,13 +125,6 @@ static char const* cutlassGetStatusString(cutlass::Status status) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct Debug {
typename T::X x;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
static const int NUM_THREADS_PER_WARP = 32;
static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2;
static const int NUM_THREADS_PER_QUAD = 4;
@ -143,7 +136,7 @@ static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2;
CUTLASS_DEVICE
int LaneId() {
int ret;
asm ("mov.u32 %0, %%laneid;" : "=r"(ret));
asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : );
return ret;
}
@ -151,7 +144,7 @@ int LaneId() {
CUTLASS_DEVICE
int SmId() {
int ret;
asm ("mov.u32 %0, %%smid;" : "=r"(ret));
asm ("mov.u32 %0, %%smid;" : "=r"(ret) : );
return ret;
}

View File

@ -31,9 +31,8 @@
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/constants.h"
#include "cutlass/complex.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/functional.h"
@ -108,6 +107,40 @@ struct Sigmoid<Array<T, N> > {
}
};
// GELU operator
template <typename T>
struct GELU {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
}
};
template <>
struct GELU<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &scalar) const {
return cutlass::constants::half<float>() * scalar *
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
}
};
template <typename T, int N>
struct GELU<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> y;
GELU<T> gelu_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < int(rhs.size()); ++i) {
y[i] = gelu_op(rhs[i]);
}
return y;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
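// Usage sketch (not part of this commit): the scalar functor above implements
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and the Array form applies it lane-wise.
// The function name is illustrative only.
CUTLASS_HOST_DEVICE float gelu_example(float x) {
  cutlass::epilogue::thread::GELU<float> gelu_op;
  return gelu_op(x);   // e.g. gelu_example(0.0f) == 0.0f
}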

View File

@ -95,6 +95,13 @@ public:
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha
): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
@ -102,6 +109,13 @@ public:
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
}
};
private:

View File

@ -236,6 +236,18 @@ public:
using ElementAccumulator = int;
using ElementCompute = float;
static_assert(
platform::is_same<ElementOutput, int32_t>::value ||
platform::is_same<ElementOutput, uint32_t>::value ||
platform::is_same<ElementOutput, int16_t>::value ||
platform::is_same<ElementOutput, uint16_t>::value ||
platform::is_same<ElementOutput, int8_t>::value ||
platform::is_same<ElementOutput, uint8_t>::value ||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value,
"This elementwise op expects the output to be int.");
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
@ -392,8 +404,9 @@ public:
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Note: The below method only works for small k dimensions. The default
/// approach is above
/// Note: The below method only works when problem_size_K <= 256 for signed int8 gemm
/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
/// above.
/// TODO: Add logic to fallback to the default approach
template <
/// Data type used to load and store tensors
@ -408,6 +421,18 @@ class FastLinearCombinationClamp {
using ElementAccumulator = int;
using ElementCompute = float;
static_assert(
platform::is_same<ElementOutput, int32_t>::value ||
platform::is_same<ElementOutput, uint32_t>::value ||
platform::is_same<ElementOutput, int16_t>::value ||
platform::is_same<ElementOutput, uint16_t>::value ||
platform::is_same<ElementOutput, int8_t>::value ||
platform::is_same<ElementOutput, uint8_t>::value ||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value,
"This elementwise op expects the output to be int.");
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;

View File

@ -0,0 +1,206 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with GELU operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationGELU {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationGELU(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition) {
if (k_partition) {
beta_ = ElementCompute(1);
}
}
/// Computes: D = gelu( alpha * accumulator + beta * source )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = gelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes: D = gelu( alpha * accumulator )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
intermediate = gelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass

View File

@ -78,7 +78,8 @@ template <
int ElementsPerAccess,
/// Multiply-add operator
/// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
typename Operator_ = arch::OpMultiplyAddComplex>
typename Operator_ = arch::OpMultiplyAddComplex
>
struct DefaultEpilogueComplexTensorOp {
using Shape = Shape_;
@ -87,7 +88,6 @@ struct DefaultEpilogueComplexTensorOp {
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using Operator = Operator_;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
@ -164,7 +164,8 @@ template <
>
struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
OutputOp_, ElementsPerAccess,
arch::OpMultiplyAddGaussianComplex> {
arch::OpMultiplyAddGaussianComplex
> {
using Shape = Shape_;
using WarpMmaTensorOp = WarpMmaTensorOp_;
@ -172,7 +173,6 @@ struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using Operator = arch::OpMultiplyAddGaussianComplex;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

View File

@ -251,7 +251,6 @@ struct DefaultEpilogueTensorOp {
static int const kPartitionsK = PartitionsK;
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

View File

@ -68,7 +68,7 @@ struct DefaultThreadMapSimt {
static_assert(
!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
/// Number of warps
using WarpCount = gemm::GemmShape<

View File

@ -69,7 +69,7 @@ struct DefaultThreadMapTensorOp {
static_assert(
!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
/// Number of warps
using WarpCount = gemm::GemmShape<
@ -119,7 +119,7 @@ struct DefaultInterleavedThreadMapTensorOp {
static int const kWarpSize = 32;
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM),
!(ThreadblockShape::kN % WarpShape::kN),
"Divisibility");
/// Number of warps

View File

@ -88,7 +88,7 @@ struct DefaultThreadMapVoltaTensorOp<
static_assert(
!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
/// Number of warps
using WarpCount = gemm::GemmShape<
@ -169,7 +169,7 @@ struct DefaultThreadMapVoltaTensorOp<
static_assert(
!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
/// Number of warps
using WarpCount = gemm::GemmShape<

View File

@ -71,7 +71,7 @@ struct DefaultThreadMapWmmaTensorOp {
static_assert(
!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
/// Number of warps
using WarpCount = gemm::GemmShape<

View File

@ -104,7 +104,6 @@ public:
using OutputOp = OutputOp_;
using Padding = Padding_;
/// Output layout is always row-major
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
@ -164,7 +163,7 @@ public:
typename Base::SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
int lane_idx ///< Id of thread within warp
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx) { }
@ -192,7 +191,8 @@ private:
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators) { ///< Complete warp-level accumulator tile
AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile
) {
//
// Iterator over warp-level accumulator fragment
@ -259,9 +259,9 @@ private:
// Store the final result
//
destination_iterator.store(output_fragment);
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
@ -272,7 +272,8 @@ private:
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
) {
typename OutputTileIterator::Fragment source_fragment;

View File

@ -41,6 +41,7 @@
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
////////////////////////////////////////////////////////////////////////////////
@ -54,7 +55,7 @@ namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
/// Tile iterator used to load and store output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
///
@ -243,8 +244,8 @@ public:
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord()
):
params_(params) {
): params_(params)
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
@ -336,7 +337,8 @@ public:
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
load_with_byte_offset(frag, 0);
load_with_byte_offset(frag, 0);
}
/// Stores a fragment to memory
@ -397,7 +399,8 @@ public:
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) {
store_with_byte_offset(frag, 0);
store_with_byte_offset(frag, 0);
}
/// Advances to the next position to load or store

View File

@ -60,8 +60,8 @@ struct TensorOpPolicy<WarpShape, OperatorShape, layout::RowMajor> {
/// Number of operations
using OperatorCount = MatrixShape<
WarpShape::kM / OperatorShape::kM,
WarpShape::kN / OperatorShape::kN
(WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM,
(WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN
>;
//
@ -70,6 +70,8 @@ struct TensorOpPolicy<WarpShape, OperatorShape, layout::RowMajor> {
static int const kElementsPerAccess = 2;
static int const kRowsPerIteration = 8;
static bool const kDivisible =
!(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN);
//
// Derived quantities
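The change above replaces exact division with a ceiling division so that warp shapes that are not whole multiples of the operator shape still produce a valid operation count, with kDivisible recording whether the exact tiling holds. The same arithmetic with plain integers (shapes assumed only for illustration):

// Assumed example: WarpShape 64x64 tiled by OperatorShape 16x8.
constexpr int kWarpM = 64, kWarpN = 64;
constexpr int kOpM = 16, kOpN = 8;
constexpr int kCountM = (kWarpM + kOpM - 1) / kOpM;                        // 4 operations along M
constexpr int kCountN = (kWarpN + kOpN - 1) / kOpN;                        // 8 operations along N
constexpr bool kDivisibleExample = !(kWarpM % kOpM) && !(kWarpN % kOpN);   // true here
static_assert(kCountM == 4 && kCountN == 8 && kDivisibleExample, "ceil-div tiling");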

View File

@ -29,6 +29,7 @@
#pragma once
#include "cutlass/array.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
@ -116,6 +117,9 @@ private:
/// Internal layout object
Layout layout_;
/// Thread offset
MatrixCoord thread_offset_;
public:
/// Default constructor
@ -129,12 +133,16 @@ public:
unsigned lane_id
):
pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0] / Policy::kElementsPerAccess) {
layout_(ref.stride()[0] / Policy::kElementsPerAccess) {
int quad_id = (lane_id / Detail::kLanesInQuad);
int lane_in_quad = (lane_id % Detail::kLanesInQuad);
pointer_ += layout_({quad_id, lane_in_quad});
thread_offset_ = {
quad_id, lane_in_quad * Policy::kElementsPerAccess
};
pointer_ += layout_({thread_offset_.row(), thread_offset_.column() / Policy::kElementsPerAccess});
}
/// Adds a pointer offset
@ -148,9 +156,16 @@ public:
CUTLASS_HOST_DEVICE
TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) {
pointer_ += layout_({
MatrixCoord coord_offset(
tile_offset.row() * Shape::kRow,
(tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess)
tile_offset.column() * Shape::kColumn
);
thread_offset_ += coord_offset;
pointer_ += layout_({
coord_offset.row(),
coord_offset.column() / Policy::kElementsPerAccess
});
return *this;
@ -198,6 +213,235 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
CUTLASS_HOST_DEVICE
TileIteratorTensorOp & operator++() {
return add_tile_offset({1, 0});
}
};
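Both the constructor and add_tile_offset above go through the same lane mapping: a lane's starting coordinate is (quad_id, lane_in_quad * kElementsPerAccess), and tile offsets are accumulated into thread_offset_ before being converted to a pointer offset. A standalone sketch of that mapping, assuming kLanesInQuad = 4 and kElementsPerAccess = 2 (the values used by the row-major TensorOpPolicy shown earlier):

// Illustrative only; not CUTLASS code.
int lane_id = 13;                  // any lane in [0, 32)
int quad_id = lane_id / 4;         // 3
int lane_in_quad = lane_id % 4;    // 1
int start_row = quad_id;           // thread_offset_.row()    == 3
int start_col = lane_in_quad * 2;  // thread_offset_.column() == 2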
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template for reading and writing tiles of accumulators to shared memory
template <
typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
typename Element_, ///< data type of element to be written
typename Layout_
>
class TileIteratorTensorOpCanonical {
public:
using WarpShape = WarpShape_;
using OperatorShape = OperatorShape_;
using Element = Element_;
using Layout = Layout_;
using TensorRef = TensorRef<Element, Layout>; ///< Tensor Reference object
using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor
using Index = typename TensorRef::Index;
using LongIndex = typename TensorRef::LongIndex;
using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;
static int const kAccessSize = 1;
static int const kAccessCount = Policy::kElementsPerAccess / kAccessSize;
/// Shape of the tile in memory
using Shape = MatrixShape<
Policy::kRowsPerIteration,
WarpShape::kN
>;
/// This is the fragment size produced by one access of the iterator.
using Fragment = Array<
Element,
Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
/// This is the complete warp-level accumulator tile.
//using AccumulatorTile = typename Operator::FragmentC;
/// Number of times this iterator can be incremented
static int const kIterations = Policy::kIterations;
// Internal constants
struct Detail {
static int const kLanesInQuad = 4;
};
/// Padding quantity
using Padding = MatrixShape<
0,
Detail::kLanesInQuad * Policy::kElementsPerAccess>;
private:
/// Storage type for accessing memory
using AccessType = AlignedArray<Element, kAccessSize>;
//
// Data members
//
/// Internal pointer to memory
AccessType *pointer_;
/// Internal layout object
Layout layout_;
/// Guard to indicate whether the shape is divisible
bool divisible_;
/// Extent of the output tensor
MatrixCoord extent_;
/// Thread offset
MatrixCoord thread_offset_;
public:
/// Default constructor
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical(): pointer_(nullptr) { }
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical(
TensorRef const &ref,
unsigned lane_id
):
pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0]),
divisible_(true),
extent_(WarpShape::kM, WarpShape::kN) {
int quad_id = (lane_id / Detail::kLanesInQuad);
int lane_in_quad = (lane_id % Detail::kLanesInQuad);
thread_offset_ = {
quad_id, lane_in_quad * Policy::kElementsPerAccess
};
pointer_ += layout_({thread_offset_.row(), thread_offset_.column()});
}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical(
TensorRef const &ref,
TensorCoord const &extent,
unsigned lane_id
):
pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0]),
divisible_(false),
extent_(extent) {
int quad_id = (lane_id / Detail::kLanesInQuad);
int lane_in_quad = (lane_id % Detail::kLanesInQuad);
thread_offset_ = {
quad_id, lane_in_quad * Policy::kElementsPerAccess
};
pointer_ += layout_({thread_offset_.row(), thread_offset_.column()});
}
/// Adds a pointer offset
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical & add_pointer_offset(Index pointer_offset) {
pointer_ += pointer_offset;
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical & add_tile_offset(TensorCoord const &tile_offset) {
MatrixCoord coord_offset(
tile_offset.row() * Shape::kRow,
tile_offset.column() * Shape::kColumn
);
thread_offset_ += coord_offset;
pointer_ += layout_({
coord_offset.row(),
coord_offset.column()
});
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical & operator+=(TensorCoord const &tile_offset) {
add_tile_offset(tile_offset);
return *this;
}
/// Store
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
CUTLASS_PRAGMA_UNROLL
for (int a = 0; a < kAccessCount; ++a) {
int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a;
int frag_idx = n * kAccessCount + a;
int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a;
if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) {
pointer_[ptr_idx] = frag_ptr[frag_idx];
}
}
}
}
/// Store
CUTLASS_HOST_DEVICE
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}
/// Load
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
CUTLASS_PRAGMA_UNROLL
for (int a = 0; a < kAccessCount; ++a) {
int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a;
int frag_idx = n * kAccessCount + a;
int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a;
if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) {
frag_ptr[frag_idx] = pointer_[ptr_idx];
}
}
}
}
/// Load
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
CUTLASS_HOST_DEVICE
TileIteratorTensorOpCanonical & operator++() {
return add_tile_offset({1, 0});
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -33,6 +33,7 @@
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/epilogue/warp/tensor_op_policy.h"
#include "cutlass/epilogue/warp/volta_tensor_op_policy.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -29,6 +29,7 @@
#include <cuda/std/cstdint>
#else
#include <cstdint>
#include <cmath>
#endif
#include "cutlass/cutlass.h"
@ -40,6 +41,8 @@
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Static math utilities
******************************************************************************/
@ -136,6 +139,19 @@ CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
return temp ? (a / temp * b) : 0;
}
/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
constexpr int round_up(int a, int b) {
return ((a + b - 1) / b) * b;
}
/// Returns the ceiling of (a / b)
CUTLASS_HOST_DEVICE
constexpr int ceil_div(int a, int b) {
return (a + b - 1) / b;
}
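A quick check of the arithmetic these helpers perform, for example when tiling an extent of 130 with 64-wide threadblock tiles (values chosen only for illustration):

static_assert(cutlass::ceil_div(130, 64) == 3,   "three tiles cover the extent");
static_assert(cutlass::round_up(130, 64) == 192, "padded extent is the next multiple of 64");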
/**
* log2 computation, what's the
* difference between the below codes and
@ -189,7 +205,6 @@ void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigne
// The remainder.
rem = src - (quo * div);
}
// For long int input
@ -206,17 +221,56 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul,
rem = src - (quo * div);
}
/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
int round_up(int a, int b) {
return ((a + b - 1) / b) * b;
}
/// Object to encapsulate the fast division+modulus operation.
///
/// This object precomputes two values used to accelerate the computation and is best used
/// when the divisor is a grid-invariant. In this case, it may be computed in host code and
/// marshalled along other kernel arguments using the 'Params' pattern.
///
/// Example:
///
///
/// int quotient, remainder, dividend, divisor;
///
/// FastDivmod divmod(divisor);
///
/// divmod(quotient, remainder, dividend);
///
/// // quotient = (dividend / divisor)
/// // remainder = (dividend % divisor)
///
struct FastDivmod {
/// Returns the ceiling of (a / b)
CUTLASS_HOST_DEVICE
int ceil_div(int a, int b) {
return (a + b - 1) / b;
}
int divisor;
unsigned int multiplier;
unsigned int shift_right;
/// Construct the FastDivmod object, in host code ideally.
///
/// This precomputes some values based on the divisor and is computationally expensive.
CUTLASS_HOST_DEVICE
FastDivmod(): divisor(0), multiplier(0), shift_right(0) { }
CUTLASS_HOST_DEVICE
FastDivmod(int divisor_): divisor(divisor_) {
find_divisor(multiplier, shift_right, divisor);
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int &remainder, int dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int64_t &remainder, int64_t dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
}
};
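A typical use is mapping a linear tile index back to two-dimensional tile coordinates, with the divisor fixed once on the host so the per-thread divmod reduces to a multiply and a shift. A small host-side sketch with assumed values:

cutlass::FastDivmod divmod_n(12);   // tiled extent along N, precomputed once

int tile_m, tile_n;
divmod_n(tile_m, tile_n, 41);       // tile_m = 41 / 12 = 3, tile_n = 41 % 12 = 5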
/******************************************************************************
* Min/Max
@ -242,4 +296,117 @@ constexpr int const_max(int a, int b) {
return (b > a ? b : a);
}
CUTLASS_HOST_DEVICE
float fast_cos(float theta) {
#if defined(__CUDA_ARCH__)
return ::cosf(theta);
#else
return std::cos(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_cos(double theta) {
#if defined(__CUDA_ARCH__)
return ::cos(theta);
#else
return std::cos(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_sin(float theta) {
#if defined(__CUDA_ARCH__)
return ::sinf(theta);
#else
return std::sin(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_sin(double theta) {
#if defined(__CUDA_ARCH__)
return ::sin(theta);
#else
return std::sin(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_acos(float theta) {
#if defined(__CUDA_ARCH__)
return ::acosf(theta);
#else
return std::acos(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_acos(double theta) {
#if defined(__CUDA_ARCH__)
return ::acos(theta);
#else
return std::acos(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_asin(float theta) {
#if defined(__CUDA_ARCH__)
return ::asinf(theta);
#else
return std::asin(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_asin(double theta) {
#if defined(__CUDA_ARCH__)
return ::asin(theta);
#else
return std::asin(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_sqrt(float theta) {
#if defined(__CUDA_ARCH__)
return ::sqrtf(theta);
#else
return std::sqrt(theta);
#endif
}
CUTLASS_HOST_DEVICE
double fast_sqrt(double theta) {
#if defined(__CUDA_ARCH__)
return ::sqrt(theta);
#else
return std::sqrt(theta);
#endif
}
CUTLASS_HOST_DEVICE
float fast_log(float x) {
#if defined(__CUDA_ARCH__)
return ::logf(x);
#else
return std::log(x);
#endif
}
CUTLASS_HOST_DEVICE
double fast_log(double x) {
#if defined(__CUDA_ARCH__)
return ::log(x);
#else
return std::log(x);
#endif
}
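These wrappers select the CUDA device math function when compiled for the device and the <cmath> overload on the host, so a single CUTLASS_HOST_DEVICE function can call them from either side. A hypothetical helper using them (not part of this file):

CUTLASS_HOST_DEVICE
float polar_to_x(float r, float theta) {
  return r * fast_cos(theta);   // ::cosf on the device, std::cos on the host
}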
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -32,9 +32,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/complex.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
@ -69,6 +67,82 @@ struct multiplies {
}
};
/// Squares with optional conversion
template <typename T, typename Output = T>
struct square {
CUTLASS_HOST_DEVICE
Output operator()(T lhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs);
return mul_op(y, y);
}
};
/// Returns the magnitude squared of an element.
template <typename T, typename Output = T>
struct magnitude_squared {
CUTLASS_HOST_DEVICE
Output operator()(T lhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs);
return mul_op(y, y);
}
};
/// Computes the magnitude squared of a complex number with optional conversion
template <typename T, typename Output>
struct magnitude_squared<complex<T>, Output> {
CUTLASS_HOST_DEVICE
Output operator()(complex<T> lhs) const {
multiplies<Output> mul_op;
Output y_r = Output(lhs.real());
Output y_i = Output(lhs.imag());
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
}
};
/// Computes the square of a difference with optional conversion
template <typename T, typename Output = T>
struct square_difference {
CUTLASS_HOST_DEVICE
Output operator()(T lhs, T rhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs) - Output(rhs);
return mul_op(y, y);
}
};
/// Computes the square of a difference with optional conversion
template <typename T, typename Output = T>
struct magnitude_squared_difference {
CUTLASS_HOST_DEVICE
Output operator()(T lhs, T rhs) const {
multiplies<Output> mul_op;
Output y = Output(lhs) - Output(rhs);
return mul_op(y, y);
}
};
/// Computes the magnitude squared of a difference of complex numbers with optional conversion
template <typename T, typename Output>
struct magnitude_squared_difference<complex<T>, Output> {
CUTLASS_HOST_DEVICE
Output operator()(complex<T> lhs, complex<T> rhs) const {
multiplies<Output> mul_op;
Output y_r = Output(lhs.real()) - Output(rhs.real());
Output y_i = Output(lhs.imag()) - Output(rhs.imag());
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
}
};
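For complex arguments the magnitude-squared functors sum the squared real and imaginary parts directly, avoiding any square root; for example (values assumed):

cutlass::complex<float> z(3.0f, 4.0f);
float norm_sq = cutlass::magnitude_squared<cutlass::complex<float>, float>()(z);  // 9 + 16 = 25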
template <typename T>
struct divides {
CUTLASS_HOST_DEVICE

View File

@ -0,0 +1,517 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/sparse_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_sparse.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may
be invoked from host code.
The contributions of this class are:
1. At compile time, it maps data types and high-level structural parameters onto
specific CUTLASS components.
2. At runtime, it maps logical arguments to GEMM problems to kernel parameters.
3. At runtime, it launches kernels on the device.
The intent is to provide a convenient mechanism for interacting with most plausible GEMM
configurations for each supported architecture. Consequently, not all parameters are exposed
to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy
are selected to trade off simplicity of the interface against flexibility. We expect
most configurations to be specified at this level. Applications with more exotic requirements
may construct their kernels of interest using CUTLASS components at the threadblock, warp,
and thread levels of abstraction.
CUTLASS exposes computations using the functor design pattern in which objects compose some
internal state with an overloaded function call operator. This enables decoupling of
initialization from execution, possibly reducing overhead during steady state phases of
application execution.
CUTLASS device-level operators expose an Arguments structure encompassing each logical
input to the computation. This is distinct from the kernel-level Params structure pattern
which contains application-specific precomputed state needed by the device code.
Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN
is as follows:
//
// Instantiate the CUTLASS GEMM operator.
//
cutlass::gemm::device::Gemm<
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor
> gemm_op;
//
// Launch the GEMM operation on the device
//
cutlass::Status status = gemm_op({
{m, n, k}, // GemmCoord problem_size,
{A, lda}, // TensorRef<float, layout::ColumnMajor> ref_A,
{B, ldb}, // TensorRef<float, layout::ColumnMajor> ref_B,
{C, ldc}, // TensorRef<float, layout::ColumnMajor> ref_C,
{D, ldd}, // TensorRef<float, layout::ColumnMajor> ref_D,
{alpha, beta} // EpilogueOutputOp::Params epilogue_op_params
});
A simplified view of the template is listed below.
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages
>
class Gemm;
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ =
typename threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Whether Beta is zero or not
bool IsBetaZero = false>
class SparseGemm {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static bool const kSplitKSerial = SplitKSerial;
static bool const kIsBetaZero = IsBetaZero;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Define the kernel
using GemmKernel = typename kernel::DefaultSparseGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator,
kIsBetaZero
>::GemmKernel;
using ElementE = typename GemmKernel::ElementE;
using LayoutE = typename GemmKernel::LayoutE;
static int const kAlignmentE = 128 / sizeof_bits<ElementE>::value;
static int const kSparse = GemmKernel::kSparse;
static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits;
static int const kElementsPerElementE = GemmKernel::kElementsPerElementE;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A;
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementC const, LayoutC> ref_C;
TensorRef<ElementC, LayoutC> ref_D;
TensorRef<ElementE const, LayoutE> ref_E;
typename EpilogueOutputOp::Params epilogue;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
TensorRef<ElementE, LayoutE> ref_E_,
typename EpilogueOutputOp::Params epilogue_ =
typename EpilogueOutputOp::Params(),
int split_k_slices = 1
):
problem_size(problem_size_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ref_E(ref_E_),
epilogue(epilogue_),
split_k_slices(split_k_slices) {
}
};
private:
/// Kernel parameters object
typename GemmKernel::Params params_;
public:
/// Constructs the GEMM.
SparseGemm() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = GemmKernel::can_implement(
args.problem_size,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_C.non_const_ref(),
args.ref_D,
args.ref_E.non_const_ref()
);
if (status != Status::kSuccess) {
return status;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// Initialize the Params structure
params_ = typename GemmKernel::Params{
args.problem_size,
grid_shape,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_C.non_const_ref(),
args.ref_D,
args.ref_E.non_const_ref(),
args.epilogue,
static_cast<int *>(workspace)
};
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A.reset(args.ref_A.non_const_ref().data());
params_.ref_B.reset(args.ref_B.non_const_ref().data());
params_.ref_C.reset(args.ref_C.non_const_ref().data());
params_.ref_D.reset(args.ref_D.data());
params_.ref_E.reset(args.ref_E.non_const_ref().data());
params_.output_op = args.epilogue;
params_.semaphore = static_cast<int *>(workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
cudaError_t result;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
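Usage follows the dense cutlass::gemm::device::Gemm pattern shown in the comment above, with one extra operand carrying the sparsity metadata. The sketch below is schematic only; the element types, layouts, and architecture tag are placeholders and may not match a supported configuration (see the SM80 sparse unit tests and example 15 for ones that are):

using SparseGemmOp = cutlass::gemm::device::SparseGemm<
  cutlass::half_t, cutlass::layout::RowMajor,      // A, stored with k / kSparse columns
  cutlass::half_t, cutlass::layout::ColumnMajor,   // B
  float,           cutlass::layout::RowMajor,      // C and D
  float,                                           // accumulator
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80>;

SparseGemmOp gemm_op;

cutlass::Status status = gemm_op({
  {m, n, k},        // logical problem size
  {A, lda},         // compressed A operand
  {B, ldb},
  {C, ldc},
  {D, ldd},
  {E, lde},         // metadata tensor (SparseGemmOp::ElementE / LayoutE, reordered as required)
  {alpha, beta}});  // epilogue parameters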

View File

@ -30,6 +30,8 @@
#pragma once
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
@ -42,6 +44,8 @@
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -121,13 +125,30 @@ public:
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
grid.z <= std::numeric_limits<uint16_t>::max())) {
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
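The new check enforces CUDA's launch limits: gridDim.x may be up to 2^31 - 1, but gridDim.y and gridDim.z are capped at 65535, so a swizzled grid that spills past that bound in y or z cannot be launched and is rejected here rather than failing at kernel launch. The bound itself, for reference:

static_assert(std::numeric_limits<uint16_t>::max() == 65535, "CUDA gridDim.y/z limit");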
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
@ -151,28 +172,41 @@ public:
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
return workspace_bytes;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(
" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return threadblock_swizzle.get_grid_shape(grid_tiled_shape);
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
@ -182,6 +216,7 @@ public:
smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
}
@ -195,6 +230,11 @@ public:
0);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
return -1;
}
@ -216,27 +256,43 @@ public:
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
return std::min(max_active_blocks, smem_capacity / smem_size);
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes) {
if (!workspace) {
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm) {
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
@ -262,6 +318,8 @@ public:
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
@ -275,6 +333,7 @@ public:
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
ThreadblockSwizzle threadblock_swizzle;
@ -302,11 +361,19 @@ public:
}
}
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block
<< "), SMEM: " << smem_size << " bytes");
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.

View File

@ -96,7 +96,7 @@ struct GemmCoord : public Coord<3, int> {
/// Integer-valued index
typedef int Index;
/// Base type is a Coord of rank=4
/// Base type is a Coord of rank=3
typedef Coord<3, Index> Base;
/// GEMM M dimension - rows of the output C matrix
@ -274,7 +274,7 @@ struct BatchedGemmCoord : public Coord<4, int> {
/// GEMM K dimension - inner dimension of the GEMM problem
static int const kK = 2;
/// GEMM K dimension - inner dimension of the GEMM problem
/// GEMM Batch dimension - number of GEMM problems in the batch
static int const kBatch = 3;
//

View File

@ -0,0 +1,187 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/sparse_gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
#include "cutlass/gemm/threadblock/default_sparse_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
#endif //CUTLASS_ARCH_WMMA_ENABLED
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Beta is zero or not
bool IsBetaZero = false>
struct DefaultSparseGemm;
////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultSparseGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::SparseGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
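kPartitionsK above counts how many warps split the threadblock's K extent; when it is greater than one, the epilogue reduces the partial accumulators across those partitions. With plain integers (shapes assumed for illustration only):

constexpr int kThreadblockK = 64;                              // ThreadblockShape::kK (assumed)
constexpr int kWarpK = 32;                                     // WarpShape::kK (assumed)
constexpr int kPartitionsKExample = kThreadblockK / kWarpK;    // 2 partial accumulators to reduce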

View File

@ -175,7 +175,8 @@ struct Gemm {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -252,7 +253,8 @@ struct Gemm {
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -133,7 +133,8 @@ struct GemmArray {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -207,7 +208,8 @@ struct GemmArray {
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -140,7 +140,8 @@ struct GemmBatched {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -219,7 +220,8 @@ struct GemmBatched {
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -66,7 +66,7 @@ __global__ void GemmPipelined(
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);
if (grid_tiled_shape.m() <= tb_tile_offset.m() ||
grid_tiled_shape.n() <= tb_tile_offset.n()) {
@ -131,7 +131,7 @@ __global__ void GemmPipelined(
warp_id,
lane_id);
tb_tile_offset = threadblock_swizzle.get_tile_offset();
tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -419,7 +419,8 @@ public:
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -549,7 +550,8 @@ public:
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -376,7 +376,8 @@ public:
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||

View File

@ -128,7 +128,8 @@ struct GemmSplitKParallel {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -205,7 +206,8 @@ struct GemmSplitKParallel {
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -36,6 +36,8 @@
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -154,6 +156,7 @@ public:
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) {
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
/// Returns arguments for the transposed problem
@ -252,6 +255,7 @@ public:
batch_stride_D(args.batch_stride_D),
semaphore(static_cast<int *>(workspace)) {
CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << problem_size);
}
CUTLASS_HOST_DEVICE
@ -264,9 +268,16 @@ public:
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};
@ -289,6 +300,8 @@ public:
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
@ -297,9 +310,12 @@ public:
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
@ -314,7 +330,8 @@ public:
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -421,7 +438,8 @@ public:
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(

View File

@ -0,0 +1,392 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined sparse GEMM kernel. Does not compute batching; serial split-K reduction is supported when SplitKSerial is enabled.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct SparseGemm {
using Mma = Mma_;
using Epilogue = Epilogue_;
using OutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
static int const kSparse = Mma::kSparse;
static int const kMetaSizeInBits = Mma::kMetaSizeInBits;
static int const kMaxID2 = Mma::kMaxID2;
static int const kElementsPerElementE = Mma::kElementsPerElementE;
using ElementE = typename Mma::ElementE;
using LayoutE = typename Mma::LayoutE;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename Mma::IteratorE::Params params_E;
typename Mma::IteratorE::TensorRef ref_E;
typename OutputOp::Params output_op;
int *semaphore;
int gemm_k_iterations;
int gemm_k_size;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
typename Mma::IteratorE::TensorRef ref_E,
typename OutputOp::Params output_op = typename OutputOp::Params(),
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
params_A(ref_A.layout()),
ref_A(ref_A),
params_B(ref_B.layout()),
ref_B(ref_B),
params_C(ref_C.layout()),
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
params_E(ref_E.layout()),
ref_E(ref_E),
output_op(output_op) {
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
semaphore = workspace;
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
SparseGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
typename Mma::IteratorE::TensorRef ref_E) {
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
static int const kAlignmentE = Mma::IteratorE::AccessType::kElements;
if (!TensorRef_aligned(ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_E, kAlignmentE)) {
return Status::kErrorMisalignedOperand;
}
if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) ||
(problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) {
return Status::kErrorMisalignedOperand;
}
// The K dimension has to be a multiple of the threadblock K because out-of-bound
// metadata would be zero-filled by cp.async, and 0 is not valid metadata.
if (problem_size.k() % Mma::Shape::kK) {
return Status::kErrorMisalignedOperand;
}
// The M dimension has to be a multiple of 32 (sparse float) or 16 (sparse int)
// because of the row reordering of operand E.
static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16;
if (problem_size.m() % kAlignmentM) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size,
threadblock_tile_offset.n() * Mma::Shape::kN
};
cutlass::MatrixCoord tb_offset_E{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(
params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A, B, and E operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k / kSparse},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
typename Mma::IteratorE iterator_E(
params.params_E, params.ref_E.data(),
{params.problem_size.m(),
problem_size_k / kSparse / kElementsPerElementE},
thread_idx, tb_offset_E);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0) {
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
}
//
// Epilogue
//
OutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
__threadfence();
semaphore.release(lock);
}
}
};
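
The serial split-K path above is driven entirely by the per-tile semaphore: partition k blocks until the lock holds k, accumulates onto the partial sums already written to D, and then releases k+1 (or 0 from the final partition, resetting the workspace for subsequent launches). The following host-side walk-through traces those lock values with an illustrative partition count of four; it simulates the protocol rather than using the device Semaphore type.

#include <cstdio>

int main() {
  // Illustrative partition count; in the kernel this is grid_tiled_shape.k().
  int grid_k = 4;
  int lock = 0;  // per-output-tile semaphore in workspace, initialized to 0

  for (int k = 0; k < grid_k; ++k) {
    // Partition k waits until the semaphore holds its own index (Semaphore::wait(k)).
    while (lock != k) { /* the kernel would spin here */ }

    // Partitions after the first read the prior partial sums from D (iterator_C = iterator_D).
    std::printf("partition %d: wait(%d), %s, ", k, k,
                k ? "accumulate onto prior partial sums" : "write first partial sums");

    // The final partition resets the lock to 0; every other partition hands it to its successor.
    lock = (k + 1 == grid_k) ? 0 : k + 1;
    std::printf("release(%d)\n", lock);
  }
  return 0;
}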
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
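
The Params constructor and the offset computations in operator() reduce to a small amount of integer arithmetic: the K extent is rounded up to whole threadblock tiles and divided across grid_tiled_shape.k() partitions, while the A and E extents shrink by the sparsity factor and the metadata packing factor. A standalone sketch of that arithmetic follows, using assumed values (kSparse = 2, kElementsPerElementE = 8, Mma::Shape::kK = 64, K = 1024 split four ways) that do not come from any particular kernel instantiation.

#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative values only.
  int const kSparse = 2;               // A is 50% structured-sparse
  int const kElementsPerElementE = 8;  // metadata elements packed per ElementE word
  int const kThreadblockK = 64;        // Mma::Shape::kK
  int problem_k = 1024;
  int grid_k = 4;                      // grid_tiled_shape.k(): number of split-K partitions

  // Mirrors SparseGemm::Params: round K up to whole threadblock tiles, then split.
  int total_iters = (problem_k + kThreadblockK - 1) / kThreadblockK;
  int iters_per_cta = (total_iters + grid_k - 1) / grid_k;
  int gemm_k_size = iters_per_cta * kThreadblockK;

  for (int k_idx = 0; k_idx < grid_k; ++k_idx) {
    // Per-CTA K extent, clamped to the problem size (operator() does the same with min()).
    int k_begin = k_idx * gemm_k_size;
    int k_end = std::min(problem_k, (k_idx + 1) * gemm_k_size);
    std::printf("partition %d: K [%d, %d)  A columns [%d, %d)  E columns [%d, %d)\n",
                k_idx, k_begin, k_end,
                k_begin / kSparse, k_end / kSparse,
                k_begin / kSparse / kElementsPerElementE,
                k_end / kSparse / kElementsPerElementE);
  }
  return 0;
}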

View File

@ -229,6 +229,16 @@ struct Mma<
/// C operand storage
using FragmentC = Array<ElementC, Shape::kMN>;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename MmaGeneric<
Shape,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
Operator>::MmaOp;
//
// Methods
//

View File

@ -977,6 +977,30 @@ struct Mma<
/// C operand storage
using FragmentC = Array<ElementC, Shape::kMN>;
static bool const a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;
static bool const b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;
static bool const c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;
static bool const c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;
static bool const m_mod2 = !(Shape::kM % 2);
static bool const n_mod2 = !(Shape::kN % 2);
static bool const k_mod2 = !(Shape::kK % 2);
// HFMA-based MMA optimizations come in two flavors:
//   1. Inner product
//   2. Outer product
// The outer-product path is selected based on LayoutC; the inner-product path
// based on LayoutA and LayoutB, or on a 1x1x2K shape.
// If neither applies, the generic MMA is used.
static bool const use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
static bool const use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
static bool const use_optimized = (use_outer_prod || use_inner_prod);
using ArchMmaOperator = typename platform::conditional< use_optimized,
detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,
MmaGeneric <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator>
>::type;
//
// Methods
//
@ -989,30 +1013,8 @@ struct Mma<
FragmentB const & B,
FragmentC const & C) {
constexpr bool a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;
constexpr bool b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;
constexpr bool c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;
constexpr bool c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;
ArchMmaOperator mma;
constexpr bool m_mod2 = !(Shape::kM % 2);
constexpr bool n_mod2 = !(Shape::kN % 2);
constexpr bool k_mod2 = !(Shape::kK % 2);
// HFMA based MMA optimizations are of 2 types :
// 1. Inner product
// 2. Outer product
// It is chosen based on LayoutC (for outer product gemm) or
// Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms)
// If all fails, we choose the generic MMA
constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
constexpr bool use_optimized = (use_outer_prod || use_inner_prod);
typename platform::conditional< use_optimized,
detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,
MmaGeneric <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator>
>::type mma;
mma(D, A, B, C);
}
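
The dispatch above happens entirely at compile time: layout and parity predicates select the HFMA2 outer-product path, the HFMA2 inner-product path, or the generic fallback. The following standalone sketch re-derives the same predicate with std::is_same and stand-in layout tags; the tags, the function name, and the example shapes are illustrative, not CUTLASS types.

#include <cstdio>
#include <type_traits>

// Stand-in layout tags for illustration.
struct RowMajor {};
struct ColumnMajor {};

template <int M, int N, int K, typename LayoutA, typename LayoutB, typename LayoutC>
void which_path() {
  constexpr bool a_row_major    = std::is_same<LayoutA, RowMajor>::value;
  constexpr bool b_column_major = std::is_same<LayoutB, ColumnMajor>::value;
  constexpr bool c_row_major    = std::is_same<LayoutC, RowMajor>::value;
  constexpr bool c_column_major = std::is_same<LayoutC, ColumnMajor>::value;

  constexpr bool m_mod2 = !(M % 2);
  constexpr bool n_mod2 = !(N % 2);
  constexpr bool k_mod2 = !(K % 2);

  // Same predicate as the thread-level Mma specialization above.
  constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
  constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) ||
                                  (M == 1 && N == 1 && k_mod2);

  std::printf("%dx%dx%d: %s\n", M, N, K,
              use_outer_prod ? "HFMA2 outer product"
                             : (use_inner_prod ? "HFMA2 inner product" : "generic MMA"));
}

int main() {
  which_path<8, 8, 2, ColumnMajor, RowMajor, RowMajor>();  // outer product (N even)
  which_path<1, 1, 8, RowMajor, ColumnMajor, RowMajor>();  // inner product (1x1xK)
  which_path<3, 3, 3, ColumnMajor, RowMajor, RowMajor>();  // generic fallback
  return 0;
}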
@ -1086,6 +1088,8 @@ struct Mma<
using FragmentB = Array<ElementB, Shape::kKN>;
using FragmentC = Array<ElementC, Shape::kMN>;
using ArchMmaOperator = typename TransposeMma::ArchMmaOperator;
CUTLASS_HOST_DEVICE
void operator()(
FragmentC & D,

View File

@ -93,6 +93,19 @@ struct Mma<
/// C operand storage
using FragmentC = Array<ElementC, Shape::kMN>;
/// Underlying matrix multiply operator (concept: arch::Mma)
// Use 1x1x4 IDP4A sequence for bulk of computation
using ArchMmaOperator = arch::Mma<
gemm::GemmShape<1,1,4>,
1,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
arch::OpMultiplyAdd>;
//
// Methods
//
@ -112,22 +125,11 @@ struct Mma<
D = C;
/// Use 1x1x4 IDP4A sequence for bulk of computation
using Mma = arch::Mma<
gemm::GemmShape<1,1,4>,
1,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
arch::OpMultiplyAdd>;
Mma mma;
ArchMmaOperator mma;
// Compute matrix product
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Shape::kN; ++n) {
@ -143,8 +145,8 @@ struct Mma<
mma(
tmp,
ptr_A[m * Shape::kK / Mma::Shape::kK + k],
ptr_B[n * Shape::kK / Mma::Shape::kK + k],
ptr_A[m * Shape::kK / ArchMmaOperator::Shape::kK + k],
ptr_B[n * Shape::kK / ArchMmaOperator::Shape::kK + k],
tmp);
d.at(mn) = reinterpret_cast<int32_t &>(tmp);
@ -206,6 +208,19 @@ struct Mma<
/// C operand storage
using FragmentC = Array<ElementC, Shape::kMN>;
/// Underlying matrix multiply operator (concept: arch::Mma)
/// Use 1x1x4 IDP4A sequence for bulk of computation
using ArchMmaOperator = arch::Mma<
gemm::GemmShape<1,1,4>,
1,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
arch::OpMultiplyAdd>;
//
// Methods
//
@ -224,25 +239,15 @@ struct Mma<
// Copy accumulators
D = C;
/// Use 1x1x4 IDP4A sequence for bulk of computation
using Mma = arch::Mma<
gemm::GemmShape<1,1,4>,
1,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
arch::OpMultiplyAdd>;
Mma mma;
/// Underlying matrix multiply operator
ArchMmaOperator mma;
Array<int8_t, 4> const *ptr_A = reinterpret_cast<Array<int8_t, 4> const *>(&A);
Array<int8_t, 4> const *ptr_B = reinterpret_cast<Array<int8_t, 4> const *>(&B);
// Compute matrix product
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Shape::kN; ++n) {

View File

@ -36,6 +36,7 @@
#include "cutlass/layout/matrix.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"

View File

@ -1,197 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
layout of the global memory fragments, data types, and internal tile sizes.
Partial specializations for threadblock::Mma operations targeting TensorOp instructions.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/gemm/warp/mma_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: column-major
/// B: row-major
/// InstructionShape: 1-by-1-by-1
/// Operator: SIMT
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Operation performed by GEMM
typename Operator_>
struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
layout::ColumnMajor, ElementB_, layout::RowMajor,
ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = GemmShape<1, 1, 1>;
using ElementA = ElementA_;
using LayoutA = layout::ColumnMajor;
using ElementB = ElementB_;
using LayoutB = layout::RowMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using OperatorClass = arch::OpClassSimt;
/// Number of warps present
using WarpCount = GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK
>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) &&
!(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
);
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
//
// Shared memory layouts
//
/// Shared memory layout for A operand
using SmemLayoutA = layout::ColumnMajor;
/// Shared memory layout for B operand
using SmemLayoutB = layout::RowMajor;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kM, Shape::kK>,
kThreads,
1
>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileIterator<
MatrixShape<Shape::kM, Shape::kK>,
ElementA,
SmemLayoutA,
1,
IteratorThreadMapA
>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kN, Shape::kK>,
kThreads,
1
>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileIterator<
MatrixShape<Shape::kK, Shape::kN>,
ElementB,
SmemLayoutB,
0,
IteratorThreadMapB
>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using WarpMma = cutlass::gemm::warp::MmaSimt<
WarpShape,
ElementA,
SmemLayoutA,
ElementB,
SmemLayoutB,
ElementC,
LayoutC,
warp::MmaSimtPolicy<
MatrixShape<4, 8>,
layout::RowMajorInterleaved<2>,
GemmShape<
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value,
1>
>
>;
/// Policy used to define MmaPipelined
using MmaPolicy = MmaPolicy<
WarpMma,
MatrixShape<0, 0>,
MatrixShape<0, 0>,
WarpCount::kK
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@ -1119,11 +1119,18 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, float,
/// Partial specialization:
///
/// A: column-major-interleave32
/// B: row-major-interleave32
/// A: column-major-interleave
/// B: row-major-interleave
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
///
/// Column/RowMajorInterleaved<InterleavedK>(m, n) is mapped to Column/RowMajor(m
/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators
/// can be reused. The shared store iterator is the same as the crosswise shared
/// store iterator. So, the only thing we need to do is to swap the coordinates
/// (contiguous <=> strided) used by the global iterator and the shared store
/// iterator.
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)

View File

@ -1362,6 +1362,13 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
///
/// Column/RowMajorInterleaved<InterleavedK>(m, n) is mapped to Column/RowMajor(m
/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators
/// can be reused. The shared store iterator is the same as the crosswise shared
/// store iterator. So, the only thing we need to do is to swap the coordinates
/// (contiguous <=> strided) used by the global iterator and the shared store
/// iterator.
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
@ -1608,7 +1615,7 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
kElementsPerAccess
>;
/// Transpose the ThreadMap of iterator A
/// Transpose the ThreadMap of iterator B
using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapB>;
/// Shared memory iterator to B operand
@ -1916,7 +1923,7 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
kElementsPerAccess
>;
/// Transpose the ThreadMap of iterator A
/// Transpose the ThreadMap of iterator B
using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapB>;
/// Shared memory iterator to B operand

View File

@ -0,0 +1,828 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines basic properties needed by CTA-level GEMMs assuming
expectations about data layout of the global memory fragments, data types,
and internal tile sizes.
Partial specializations for threadblock::Mma operations targeting sparse
TensorOp instructions.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"
#include "cutlass/gemm/warp/mma_simt.h"
#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/threadblock/default_mma_core.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h"
#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h"
#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h"
#include "cutlass/gemm/threadblock/mma_sparse_multistage.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Template defining default matrix multiply operators inferred from threadblock tile size,
/// global memory data layout, and target math instruction.
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
/// Shape of warp-level matrix multiply operator
typename WarpShape,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape,
/// Element data type of A operand
typename ElementA,
/// Layout of operand A
typename LayoutA,
/// Element data type of B operand
typename ElementB,
/// Layout of operand B
typename LayoutB,
/// Data type of accumulator
typename ElementC,
/// Layout of accumulator
typename LayoutC,
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
typename OperatorClass,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator = typename platform::conditional<
(platform::is_same<OperatorClass,
cutlass::arch::OpClassTensorOp>::value) &&
(platform::is_same<ElementA, int8_t>::value ||
platform::is_same<ElementA, int4b_t>::value ||
platform::is_same<ElementA, uint8_t>::value ||
platform::is_same<ElementA, uint4b_t>::value),
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false
/// Cache operation of operand A
, cutlass::arch::CacheOperation::Kind CacheOpA =
cutlass::arch::CacheOperation::Global,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB =
cutlass::arch::CacheOperation::Global
>
struct DefaultSparseMmaCore;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: column-major
/// B: row-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::ColumnMajor, ElementB_, layout::RowMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::ColumnMajor;
using ElementB = ElementB_;
using LayoutB = layout::RowMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
static int const kSparse = 2;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Default Operator
using Operator = Operator_;
//
// Shared memory layouts
//
using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous<
sizeof_bits<ElementA>::value, int(128 / sizeof(ElementA))>;
// Shared memory layout
using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous<
sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kM, Shape::kK / kSparse>, kThreads,
layout::PitchLinearShape<8, 4>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 1,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kN, Shape::kK>, kThreads,
layout::PitchLinearShape<8, 4>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Cache operation of operand E
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
cutlass::arch::CacheOperation::Global;
static int const kInterleavedE = MmaTensorOp::kInterleaved;
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
static int const kMaxID2 = MmaTensorOp::kMaxID2;
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
using ElementE = typename MmaTensorOp::ElementE;
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
using SmemLayoutE = typename MmaTensorOp::LayoutE;
/// ThreadMap of iterator E
static int const kElementsPerAccessE =
kAccessSizeInBits / sizeof_bits<ElementE>::value;
/// E is tiny. Not all warps are needed.
static int const kThreadsE =
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
kThreads)
? kThreads
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE /
kInterleavedE>,
kThreadsE, kElementsPerAccessE>;
/// Shared memory iterator to E operand
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
/// Policy used to define MmaPipelined
using MmaPolicy =
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: column-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, ElementB_, layout::ColumnMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
static int const kSparse = 2;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
// The crosswise extent cannot be larger than 1024 bits.
static int const kCrosswiseB =
(Shape::kK > (1024 / sizeof_bits<ElementB>::value))
? (1024 / sizeof_bits<ElementB>::value)
: Shape::kK;
static int const kWarpThreadArrangementContiguousB =
kCrosswiseB / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
static int const kWarpThreadArrangementStridedB =
kWarpSize / kWarpThreadArrangementContiguousB;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value, Shape::kK / kSparse>;
// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value, kCrosswiseB>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK / kSparse, Shape::kM>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 0,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Cache operation of operand E
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
cutlass::arch::CacheOperation::Global;
static int const kInterleavedE = MmaTensorOp::kInterleaved;
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
static int const kMaxID2 = MmaTensorOp::kMaxID2;
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
using ElementE = typename MmaTensorOp::ElementE;
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
using SmemLayoutE = typename MmaTensorOp::LayoutE;
/// ThreadMap of iterator E
static int const kElementsPerAccessE =
kAccessSizeInBits / sizeof_bits<ElementE>::value;
/// E is tiny. Not all warps are needed.
static int const kThreadsE =
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
kThreads)
? kThreads
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE /
kInterleavedE>,
kThreadsE, kElementsPerAccessE>;
/// Shared memory iterator to E operand
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
/// Policy used to define MmaPipelined
using MmaPolicy =
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: column-major
/// B: column-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::ColumnMajor, ElementB_, layout::ColumnMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::ColumnMajor;
using ElementB = ElementB_;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
static int const kSparse = 2;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
// The crosswise extent cannot be larger than 1024 bits.
static int const kCrosswiseB =
(Shape::kK > (1024 / sizeof_bits<ElementB>::value))
? (1024 / sizeof_bits<ElementB>::value)
: Shape::kK;
static int const kWarpThreadArrangementContiguousB =
kCrosswiseB / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
static int const kWarpThreadArrangementStridedB =
kWarpSize / kWarpThreadArrangementContiguousB;
//
// Shared memory layouts
//
using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous<
sizeof_bits<ElementA>::value, int(128 / sizeof(ElementA))>;
// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value, kCrosswiseB>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kM, Shape::kK / kSparse>, kThreads,
layout::PitchLinearShape<8, 4>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 1,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Cache operation of operand E
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
cutlass::arch::CacheOperation::Global;
static int const kInterleavedE = MmaTensorOp::kInterleaved;
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
static int const kMaxID2 = MmaTensorOp::kMaxID2;
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
using ElementE = typename MmaTensorOp::ElementE;
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
using SmemLayoutE = typename MmaTensorOp::LayoutE;
/// ThreadMap of iterator E
static int const kElementsPerAccessE =
kAccessSizeInBits / sizeof_bits<ElementE>::value;
/// E is tiny. Not all warps are needed.
static int const kThreadsE =
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
kThreads)
? kThreads
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE /
kInterleavedE>,
kThreadsE, kElementsPerAccessE>;
/// Shared memory iterator to E operand
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
/// Policy used to define MmaPipelined
using MmaPolicy =
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: row-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
LayoutC_, arch::OpClassTensorOp, Stages, Operator_,
false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = layout::RowMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
static int const kSparse = 2;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value, Shape::kK / kSparse>;
// Shared memory layout
using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous<
sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK / kSparse, Shape::kM>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 0,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kN, Shape::kK>, kThreads,
layout::PitchLinearShape<8, 4>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Cache operation of operand E
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
cutlass::arch::CacheOperation::Global;
static int const kInterleavedE = MmaTensorOp::kInterleaved;
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
static int const kMaxID2 = MmaTensorOp::kMaxID2;
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
using ElementE = typename MmaTensorOp::ElementE;
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
using SmemLayoutE = typename MmaTensorOp::LayoutE;
/// ThreadMap of iterator E
static int const kElementsPerAccessE =
kAccessSizeInBits / sizeof_bits<ElementE>::value;
/// E is tiny. Not all warps are needed.
static int const kThreadsE =
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
kThreads)
? kThreads
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE /
kInterleavedE>,
kThreadsE, kElementsPerAccessE>;
/// Shared memory iterator to E operand
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM * kInterleavedE,
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
/// Policy used to define MmaPipelined
using MmaPolicy =
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
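
Each specialization above repeats the same metadata sizing arithmetic: the A extent in K shrinks by kSparse, the E operand shrinks again by kElementsPerElementE, only as many threads as needed participate in loading E, and the crosswise shared-memory extent of a K-major operand is clamped at 1024 bits. The sketch below works through those calculations for an assumed 128x128x64 threadblock tile of half-precision operands with 16-bit metadata words (kElementsPerElementE = 8); every concrete number is illustrative rather than taken from a specific instantiation.

#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative tile and type parameters (not tied to a specific instantiation).
  int const kM = 128, kK = 64;              // threadblock tile extents used below
  int const kSparse = 2;                    // 2:4 structured sparsity on A
  int const kElementsPerElementE = 8;       // assumed packing for 16-bit metadata words
  int const kBitsPerElementB = 16;          // half_t
  int const kBitsPerElementE = 16;          // uint16_t metadata word
  int const kThreads = 128;                 // 4 warps for a 128x128 tile of 64x64 warps
  int const kAccessSizeInBits = 128;

  // The A tile stores only the nonzero half of K.
  std::printf("A tile: %d x %d\n", kM, kK / kSparse);

  // The E tile shrinks again by the metadata packing factor.
  int e_cols = kK / kSparse / kElementsPerElementE;
  std::printf("E tile: %d x %d\n", kM, e_cols);

  // "E is tiny. Not all warps are needed." - cap the thread count by the number
  // of 128-bit accesses required to move one E tile.
  int e_accesses = kM * e_cols / (kAccessSizeInBits / kBitsPerElementE);
  int kThreadsE = std::min(kThreads, e_accesses);
  std::printf("threads loading E: %d of %d\n", kThreadsE, kThreads);

  // The crosswise shared-memory extent is clamped at 1024 bits.
  int kCrosswiseB = std::min(kK, 1024 / kBitsPerElementB);
  std::printf("kCrosswiseB: %d elements\n", kCrosswiseB);
  return 0;
}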

View File

@ -0,0 +1,190 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include "cutlass/gemm/threadblock/default_mma_core_wmma.h"
#endif //CUTLASS_ARCH_WMMA_ENABLED
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false
>
struct DefaultSparseMma;
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Number of stages used in the multistage mainloop
int Stages,
/// Operation performed by GEMM
typename Operator
>
struct DefaultSparseMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
kAlignmentB, ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, Stages, Operator, false> {
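  // Select the cp.async cache hint: a full 128-bit (16-byte) access uses cache-global
  // (cached in L2 only), which is the only copy size cp.async.cg supports; smaller
  // accesses fall back to caching at all levels.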
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, Operator, false, CacheOpA, CacheOpB>;
static int const kSparse = MmaCore::kSparse;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK / kSparse>,
ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>;
// Define iterators over tiles from the E operand
using ElementE = typename MmaCore::ElementE;
using LayoutE = typename MmaCore::GmemLayoutE;
using ThreadMapE = typename MmaCore::IteratorThreadMapE;
using AccessTypeE =
cutlass::Array<ElementE, 128 / sizeof_bits<ElementE>::value>;
using IteratorE =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM,
ThreadblockShape::kK / kSparse /
MmaCore::kElementsPerElementE>,
ElementE, LayoutE, 1, ThreadMapE, AccessTypeE>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::SparseMmaMultistage<
typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor,
IteratorE, typename MmaCore::SmemIteratorE, MmaCore::kCacheOpE,
typename MmaCore::MmaPolicy, Stages>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
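As an illustration only (not part of this commit), the default above might be instantiated roughly as follows. The element types, layouts, alignments, tile shapes, and stage count are assumptions chosen to resemble a typical FP16 Sparse Tensor Core configuration rather than values mandated by this file:

using DefaultMma = cutlass::gemm::threadblock::DefaultSparseMma<
    cutlass::half_t, cutlass::layout::RowMajor,    8,     // A operand: half, row-major, 128-bit aligned
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,     // B operand: half, column-major, 128-bit aligned
    float,           cutlass::layout::RowMajor,           // accumulator type and C/D layout
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,  // Sparse Tensor Cores on SM80
    cutlass::gemm::GemmShape<128, 128, 64>,               // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 32>,                  // mma.sp instruction shape
    3,                                                    // pipeline stages
    cutlass::arch::OpMultiplyAdd>;

using ThreadblockMma = typename DefaultMma::ThreadblockMma;  // threadblock-scoped mainloop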

View File

@ -0,0 +1,259 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Base class template for a multistage threadblock-scoped sparse GEMM mainloop.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Policy object describing MmaTensorOp
template <
/// Warp-level GEMM operator (concept: gemm::warp::Mma)
typename Operator_,
/// Padding used for A operand in shared memory (concept: MatrixShape)
typename SmemPaddingA_,
/// Padding used for B operand in shared memory (concept: MatrixShape)
typename SmemPaddingB_,
/// Padding used for E operand in shared memory (concept: MatrixShape)
typename SmemPaddingE_,
/// Number of partitions of K dimension of GEMM
int PartitionsK = 1>
struct SparseMmaPolicy {
/// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
using Operator = Operator_;
/// Padding used for A operand in shared memory
using SmemPaddingA = SmemPaddingA_;
/// Padding used for B operand in shared memory
using SmemPaddingB = SmemPaddingB_;
/// Padding used for E operand in shared memory
using SmemPaddingE = SmemPaddingE_;
/// Number of partitions of K dimension
static int const kPartitionsK = PartitionsK;
};
////////////////////////////////////////////////////////////////////////////////
/// Base structure for a threadblock-scoped sparse matrix multiply targeting Sparse Tensor
/// Cores; it owns the shared memory storage and warp-level tile iterators.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class SparseMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
static int const kSparse = Operator::kSparse;
static int const kElementsPerElementE = Operator::kElementsPerElementE;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
/// Tensor reference to the E operand
using TensorRefE = TensorRef<typename Operator::ElementE, typename Operator::LayoutE>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK / kSparse * kStages +
Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB =
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Shape of the E matrix operand in shared memory
using ShapeE =
MatrixShape<Shape::kM * 2 + Policy::SmemPaddingE::kRow,
Shape::kK / kSparse / kElementsPerElementE / 2 * kStages +
Policy::SmemPaddingE::kColumn>;
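    // Rough worked example (illustrative assumptions: 128x128x64 threadblock tile,
    // half-precision operands with kSparse == 2 and kElementsPerElementE == 8,
    // kStages == 3, zero padding):
    //   ShapeA = 128 x (64/2 * 3)           = 128 x 96
    //   ShapeB = (64 * 3) x 128             = 192 x 128
    //   ShapeE = (128 * 2) x (64/2/8/2 * 3) = 256 x 6
    // The factor of 2 on the E rows (and the matching /2 on its columns) reflects the
    // 2x row interleaving of the metadata; the element count per stage is unchanged.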
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for E operand
AlignedBuffer<typename Operator::ElementE, ShapeE::kCount> operand_E;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a layout object for the E matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutE LayoutE() {
return Operator::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
/// Returns a TensorRef to the E operand
CUTLASS_HOST_DEVICE
TensorRefE operand_E_ref() {
return TensorRefE{operand_E.data(), LayoutE()};
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
/// Iterator to load a warp-scoped tile of E operand from shared memory
typename Operator::IteratorE warp_tile_iterator_E_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
SparseMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx),
warp_tile_iterator_E_(shared_storage.operand_E_ref(), lane_idx) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,667 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a multistage threadblock-scoped sparse GEMM mainloop.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/threadblock/mma_sparse_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Sparse Tensor Cores, using cp.async
/// to implement a multistage software pipeline.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Iterates over tiles of E operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorE_,
/// Iterates over tiles of E operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorE_,
/// Cache operation for operand E
cutlass::arch::CacheOperation::Kind CacheOpE,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class SparseMmaMultistage :
public SparseMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = SparseMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Iterates over tiles of E operand in global memory
using IteratorE = IteratorE_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorE = SmemIteratorE_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
static cutlass::arch::CacheOperation::Kind const kCacheOpE = CacheOpE;
static int const kSparse = Policy::Operator::kSparse;
static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits;
static int const kMaxID2 = Policy::Operator::kMaxID2;
static int const kElementsPerElementE =
Policy::Operator::kElementsPerElementE;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// ElementE
using ElementE = typename IteratorE::Element;
/// LayoutE
using LayoutE = typename IteratorE::Layout;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of async copies to load one stage of operand A
static int const TBLDGSTSIterationsA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of async copies to load one stage of operand B
static int const TBLDGSTSIterationsB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of async copies to load one stage of operand E
static int const TBLDGSTSIterationsE =
IteratorE::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of async copies to load one group of operand A
static int const kAccessesPerGroupA =
(TBLDGSTSIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of async copies to load one group of operand B
static int const kAccessesPerGroupB =
(TBLDGSTSIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of async copies to load one group of operand E
static int const kAccessesPerGroupE =
(TBLDGSTSIterationsE + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// The E operand is tiny. Most of the time, not all warps are needed
/// to load it from global memory.
static int const kValidWarps = IteratorE::ThreadMap::kThreads / 32;
/// The B operand is twice as large as A, which brings very high register pressure.
/// The double buffer has to be sacrificed when the warp tile size is large.
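  /// Concretely, per the expression below: with 32-bit accumulators, no operand type
  /// conversion (the operand element types match the instruction's), and a warp tile of
  /// at least 64x64, a single B buffer is used; otherwise two buffers are kept.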
static int const kBBufferSize =
((sizeof(typename Operator::ElementC) == 4) &&
((platform::is_same<typename Operator::Policy::Operator::ElementA,
typename Operator::ElementA>::value &&
platform::is_same<typename Operator::Policy::Operator::ElementB,
typename Operator::ElementB>::value)) &&
(Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64))
? 1
: 2;
};
private:
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using WarpFragmentE = typename Operator::FragmentE;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of E operand to shared memory
SmemIteratorE smem_iterator_E_;
/// Whether this warp participates in loading the E operand
bool is_warp_valid_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
SparseMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_iterator_E_(shared_storage.operand_E_ref(), thread_idx)
{
is_warp_valid_ = warp_idx < Detail::kValidWarps;
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
this->warp_tile_iterator_E_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B,
IteratorE &iterator_E, int group_start_A = 0,
int group_start_B = 0, int group_start_E = 0) {
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// async copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::TBLDGSTSIterationsA) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// async copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::TBLDGSTSIterationsB) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
}
iterator_E.set_iteration_index(group_start_E);
this->smem_iterator_E_.set_iteration_index(group_start_E);
// async copy for operand E
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupE; ++j) {
if (group_start_E + j < Detail::TBLDGSTSIterationsE) {
typename IteratorE::AccessType *dst_ptr =
reinterpret_cast<typename IteratorE::AccessType *>(
this->smem_iterator_E_.get());
int const kSrcBytes = sizeof_bits<typename IteratorE::Element>::value *
IteratorE::ThreadMap::kElementsPerAccess / 8;
auto gmem_ptr = iterator_E.get();
cutlass::arch::cp_async<kSrcBytes, kCacheOpE>(
dst_ptr, gmem_ptr, iterator_E.valid() && is_warp_valid_);
++iterator_E;
++this->smem_iterator_E_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC &accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over E operand in global memory
IteratorE iterator_E,
///< initial value of accumulator
FragmentC const &src_accum) {
//
// Prologue
//
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
iterator_E.clear_mask();
}
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// async copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// async copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
iterator_E.set_iteration_index(0);
this->smem_iterator_E_.set_iteration_index(0);
// async copy for operand E
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsE; ++j) {
typename IteratorE::AccessType *dst_ptr =
reinterpret_cast<typename IteratorE::AccessType *>(
this->smem_iterator_E_.get());
int const kSrcBytes = sizeof_bits<typename IteratorE::Element>::value *
IteratorE::ThreadMap::kElementsPerAccess / 8;
if (is_warp_valid_)
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpE>(
dst_ptr, iterator_E.get(), iterator_E.valid());
++iterator_E;
++this->smem_iterator_E_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
iterator_E.add_tile_offset({0, 1});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
this->smem_iterator_E_.add_tile_offset({0, 1});
// LDGDEPBAR - completes a stage
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
// DEPBAR+SYNC
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA warp_loaded_frag_A[2];
WarpLoadedFragmentB warp_loaded_frag_B[Detail::kBBufferSize];
WarpTransformedFragmentA warp_transformed_frag_A[2];
WarpTransformedFragmentB warp_transformed_frag_B[Detail::kBBufferSize];
WarpFragmentE warp_frag_E[2];
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_E_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
this->warp_tile_iterator_E_.load(warp_frag_E[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
++this->warp_tile_iterator_E_;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
iterator_E.clear_mask();
}
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to the k offset if
// this is the last group.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_E_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_E_.load(warp_frag_E[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_E_;
if (Detail::kBBufferSize == 2) {
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(
warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k > 0)
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B[warp_mma_k % Detail::kBBufferSize]);
warp_mma(
accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], accum,
warp_frag_E[warp_mma_k % 2]
);
if (Detail::kBBufferSize == 1) {
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
++this->warp_tile_iterator_B_;
}
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
group_start_iteration_E = warp_mma_k * Detail::kAccessesPerGroupE;
copy_tiles_and_advance(
iterator_A, iterator_B, iterator_E, group_start_iteration_A,
group_start_iteration_B, group_start_iteration_E);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E;
group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
group_start_iteration_E =
(warp_mma_k + 1) * Detail::kAccessesPerGroupE;
copy_tiles_and_advance(
iterator_A, iterator_B, iterator_E, group_start_iteration_A,
group_start_iteration_B, group_start_iteration_E);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
iterator_E.add_tile_offset({0, 1});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
this->smem_iterator_E_.add_tile_offset({0, 1});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_E_.add_tile_offset({0, -Base::kStages});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
this->warp_tile_iterator_E_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations;
if (gemm_k_iterations == 0) {
iterator_A.clear_mask();
iterator_B.clear_mask();
iterator_E.clear_mask();
}
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -123,16 +123,22 @@ struct GemmIdentityThreadblockSwizzle {
/// Computes CUDA grid dimensions given a size in units of logical tiles
CUTLASS_HOST_DEVICE
dim3 get_grid_shape(GemmCoord tiled_shape) const {
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
}
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset() const {
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
int block_idx_x = RematerializeBlockIdxX();
int block_idx_y = RematerializeBlockIdxY();
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()};
return GemmCoord{
(block_idx_x / kTile),
(block_idx_y * kTile) + (block_idx_x % kTile),
@ -170,7 +176,7 @@ struct GemmHorizontalThreadblockSwizzle {
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset() const {
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
return GemmCoord{
RematerializeBlockIdxY(),
RematerializeBlockIdxX(),
@ -205,7 +211,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset() const {
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
return GemmCoord{
RematerializeBlockIdxX(),
RematerializeBlockIdxY(),
@ -244,17 +250,23 @@ struct GemmSplitKIdentityThreadblockSwizzle {
/// Computes CUDA grid dimensions given a size in units of logical tiles
CUTLASS_HOST_DEVICE
dim3 get_grid_shape(GemmCoord tiled_shape) const {
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
}
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset() const {
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
int block_idx_x = RematerializeBlockIdxX();
int block_idx_y = RematerializeBlockIdxY();
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()};
return GemmCoord{
(block_idx_x / kTile),
(block_idx_y * kTile) + (block_idx_x % kTile),
@ -290,7 +302,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset() const {
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
return GemmCoord{
RematerializeBlockIdxY(),
RematerializeBlockIdxX(),
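Worked example for the identity swizzles above (kTile is defined outside these hunks; assume kTile == 8 purely for illustration): a problem tiled 16 x 16 x 1 gives get_grid_shape() == dim3(16 * 8, (16 + 7) / 8, 1) == dim3(128, 2, 1), and get_tile_offset() maps block (x, y) back to tile (m, n) = (x / 8, y * 8 + x % 8). Every tile in the 16 x 16 grid is visited exactly once, and each group of eight consecutive blocks along x covers one M-tile row and eight adjacent N-tiles, which presumably improves L2 reuse.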

View File

@ -0,0 +1,159 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_sparse_tensor_op.h"
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Shape of the warp-level GEMM (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Operator describing the tensor operation
typename Operator_ = arch::OpMultiplyAdd,
/// Number of partitions along K dimension
int PartitionsK = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false
>
struct DefaultSparseMmaTensorOp;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial Specialization - inputs and output types are float - uses TF32 internally
template <
/// Shape of the warp-level GEMM (concept: GemmShape)
typename WarpShape_,
/// Shape of target matrix multiply instruction (concept: GemmShape)
typename InstructionShape_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultSparseMmaTensorOp<
WarpShape_,
InstructionShape_,
float, LayoutA,
float, LayoutB,
float, LayoutC,
arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> {
// Uses TF32 internally
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::SparseMma<
InstructionShape_,
32,
tfloat32_t, cutlass::layout::RowMajor,
tfloat32_t, cutlass::layout::ColumnMajor,
float, cutlass::layout::RowMajor,
arch::OpMultiplyAdd
>,
cutlass::MatrixShape<1, 1> >;
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::SparseMmaTensorOp<
WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// General definition for the remaining operand types
template <
/// Shape of the warp-level GEMM (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A elements
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Data type of B elements
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Operator describing the tensor operation
typename Operator_,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultSparseMmaTensorOp {
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::SparseMma<InstructionShape_, 32, ElementA,
cutlass::layout::RowMajor, ElementB,
cutlass::layout::ColumnMajor, ElementC,
cutlass::layout::RowMajor, Operator_>,
cutlass::MatrixShape<1, 1> >;
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::SparseMmaTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -147,6 +147,9 @@ public:
dp4a_type
>;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename ThreadMma::ArchMmaOperator;
/// Shape of the underlying instruction
using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>;

View File

@ -0,0 +1,335 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate
operations targeting sparse Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute a warp-level matrix product targeting Sparse Tensor Cores.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Used for partial specialization
typename Enable = bool
>
class SparseMmaTensorOp {
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Architecture tag from underlying instruction
using ArchTag = typename Policy::Operator::ArchTag;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
/// Sparsity in Operand A
static int const kSparse = Policy::Operator::kSparse;
/// Meta data size in bits
static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits;
/// Max ID2
static int const kMaxID2 = Policy::Operator::kMaxID2;
/// Data type of the metadata operand E
using ElementE =
typename cutlass::platform::conditional<kMaxID2 == 1, uint32_t,
uint16_t>::type;
/// Number of ElementA that is associated with one ElementE
static int const kElementsPerElementE =
128 / cutlass::sizeof_bits<ElementA>::value;
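  /// (For example, 128/16 = 8 elements of half_t per metadata element, or 128/8 = 16
  /// elements of int8_t.)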
/// Meta data is essentially interleaved but mapped to ColumnMajor internally
static int const kInterleaved = 2;
/// Layout of meta E
using LayoutE = cutlass::layout::ColumnMajor;
public:
/// Iterates over the A operand in memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK / kSparse>, Operand::kA, ElementA,
LayoutA,
MatrixShape<Policy::Operator::Shape::kM,
Policy::Operator::Shape::kK / kSparse>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA =
Array<typename Policy::Operator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename Policy::Operator::ElementB, FragmentB::kElements>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename Policy::Operator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Iterates over the E operand in memory
using IteratorE = SparseMmaTensorOpMetaTileIterator<
MatrixShape<Shape::kM * kInterleaved,
Shape::kK / kSparse / kElementsPerElementE / kInterleaved>,
ElementE, LayoutE,
MatrixShape<Policy::Operator::Shape::kM,
Policy::Operator::Shape::kK / kSparse / kElementsPerElementE /
kInterleaved>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for E tile
using FragmentE = typename IteratorE::Fragment;
private:
static_assert(
!(Shape::kM % Policy::Operator::Shape::kM) &&
!(Shape::kN % Policy::Operator::Shape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
/// Number of mma operations performed
using MmaIterations = MatrixShape<
Shape::kM / Policy::Operator::Shape::kM,
Shape::kN / Policy::Operator::Shape::kN
>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
typename Policy::Operator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
SparseMmaTensorOp() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(
FragmentC &D,
TransformedFragmentA const &A,
TransformedFragmentB const &B,
FragmentC const &C,
FragmentE const &E
) const {
using MmaOperandA = typename Policy::Operator::FragmentA;
using MmaOperandB = typename Policy::Operator::FragmentB;
using MmaOperandC = typename Policy::Operator::FragmentC;
using MmaOperandE = typename Policy::Operator::FragmentE;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
D = C;
MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
MmaOperandE const *ptr_E = reinterpret_cast<MmaOperandE const *>(&E);
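    // kMaxID2 consecutive rows of mma operations share one metadata fragment; id2
    // (computed below) selects the portion that applies to the current row and is
    // passed to the sparse mma instruction as its sparsity selector.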
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {
int id2 = m % kMaxID2;
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
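        // Visit the N dimension in serpentine order (direction alternates each row),
        // presumably to improve reuse of operand fragments across consecutive mma issues.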
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
if (AccumulatorsInRowMajor) { // matrix B is reordered
mma(
ptr_D[n_serpentine + m * MmaIterations::kColumn],
ptr_A[m],
ptr_B[n_serpentine],
ptr_D[n_serpentine + m * MmaIterations::kColumn],
ptr_E[(m / kMaxID2)],
id2);
} else {
mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
ptr_A[m],
ptr_B[n_serpentine],
ptr_D[m + n_serpentine * MmaIterations::kRow],
ptr_E[(m / kMaxID2)],
id2);
}
}
}
#else
assert(0);
#endif
}
/// Transform the mma operands to the required types
CUTLASS_DEVICE
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
//
// Define conversions from source type to instruction type
//
FloatRoundStyle const kRoundA =
PreferredRoundingMode<typename Policy::Operator::ElementA,
ElementA>::kRound;
FloatRoundStyle const kRoundB =
PreferredRoundingMode<typename Policy::Operator::ElementB,
ElementB>::kRound;
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
FragmentA::kElements / 2, kRoundA>
convert_A;
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
FragmentB::kElements, kRoundB>
convert_B;
Array<ElementA, FragmentA::kElements / 2> const *ptr_A =
reinterpret_cast<Array<ElementA, FragmentA::kElements / 2> const *>(&A);
Array<typename Policy::Operator::ElementA, FragmentA::kElements / 2> *
ptr_dst_A = reinterpret_cast<Array<typename Policy::Operator::ElementA,
FragmentA::kElements / 2> *>(&dst_A);
dst_B = convert_B(B);
ptr_dst_A[0] = convert_A(ptr_A[0]);
ptr_dst_A[1] = convert_A(ptr_A[1]);
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -184,14 +184,17 @@ public:
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename Policy::Operator::ArchTag;
using ArchTag = typename ArchMmaOperator::ArchTag;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename Policy::Operator::Shape;
using InstructionShape = typename ArchMmaOperator::Shape;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
@ -210,7 +213,7 @@ public:
/// Iterates over the A operand in memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
@ -218,12 +221,12 @@ public:
/// Storage for transformed A tile
using TransformedFragmentA =
Array<typename Policy::Operator::ElementA, FragmentA::kElements>;
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for B tile
@ -231,33 +234,28 @@ public:
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename Policy::Operator::ElementB, FragmentB::kElements>;
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename Policy::Operator::Shape, typename Policy::OpDelta>;
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
private:
static_assert(
!(Shape::kM % Policy::Operator::Shape::kM) &&
!(Shape::kN % Policy::Operator::Shape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
/// Number of mma operations performed
using MmaIterations = MatrixShape<
Shape::kM / Policy::Operator::Shape::kM,
Shape::kN / Policy::Operator::Shape::kN
(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN
>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
typename Policy::Operator mma;
ArchMmaOperator mma;
public:
@ -278,9 +276,9 @@ public:
FragmentC const &C
) const {
using MmaOperandA = typename Policy::Operator::FragmentA;
using MmaOperandB = typename Policy::Operator::FragmentB;
using MmaOperandC = typename Policy::Operator::FragmentC;
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
D = C;
@ -351,22 +349,22 @@ public:
// Define conversions from source type to instruction type
//
FloatRoundStyle const kRoundA =
PreferredRoundingMode<typename Policy::Operator::ElementA,
PreferredRoundingMode<typename ArchMmaOperator::ElementA,
ElementA>::kRound;
FloatRoundStyle const kRoundB =
PreferredRoundingMode<typename Policy::Operator::ElementB,
PreferredRoundingMode<typename ArchMmaOperator::ElementB,
ElementB>::kRound;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
detail::ConvertAndPack<typename ArchMmaOperator::ElementA, ElementA,
FragmentA::kElements, kRoundA>
convert_A;
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
NumericArrayConverter<typename ArchMmaOperator::ElementB, ElementB,
FragmentB::kElements / 2, kRoundB>
convert_B;
Array<ElementB, FragmentB::kElements / 2> const *ptr_B =
reinterpret_cast<Array<ElementB, FragmentB::kElements / 2> const *>(&B);
Array<typename Policy::Operator::ElementB, FragmentB::kElements / 2> *
ptr_dst_B = reinterpret_cast<Array<typename Policy::Operator::ElementB,
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / 2> *
ptr_dst_B = reinterpret_cast<Array<typename ArchMmaOperator::ElementB,
FragmentB::kElements / 2> *>(&dst_B);
dst_A = convert_A(A);
@ -375,16 +373,16 @@ public:
ptr_dst_B[1] = convert_B(ptr_B[1]);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
detail::ConvertAndPack<typename ArchMmaOperator::ElementA, ElementA,
FragmentA::kElements / 2, kRoundA>
convert_A;
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
NumericArrayConverter<typename ArchMmaOperator::ElementB, ElementB,
FragmentB::kElements, kRoundB>
convert_B;
Array<ElementA, FragmentA::kElements / 2> const *ptr_A =
reinterpret_cast<Array<ElementA, FragmentA::kElements / 2> const *>(&A);
Array<typename Policy::Operator::ElementA, FragmentA::kElements / 2> *
ptr_dst_A = reinterpret_cast<Array<typename Policy::Operator::ElementA,
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements / 2> *
ptr_dst_A = reinterpret_cast<Array<typename ArchMmaOperator::ElementA,
FragmentA::kElements / 2> *>(&dst_A);
dst_B = convert_B(B);
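The transform above converts one operand fragment whole and the other in two halves by reinterpreting it as a pair of half-length arrays, with the split operand chosen by __CUDA_ARCH__. A minimal standalone sketch of that reinterpret-and-convert-in-halves pattern, using a hypothetical element converter and plain C++ arrays rather than CUTLASS types:

#include <array>
#include <cstdio>

// Hypothetical element-wise converter standing in for NumericArrayConverter.
struct FloatToIntConverter {
  template <std::size_t N>
  std::array<int, N> operator()(std::array<float, N> const &src) const {
    std::array<int, N> dst{};
    for (std::size_t i = 0; i < N; ++i) dst[i] = static_cast<int>(src[i]);
    return dst;
  }
};

int main() {
  constexpr std::size_t kElements = 8;
  std::array<float, kElements> B{};
  std::array<int, kElements> dst_B{};
  for (std::size_t i = 0; i < kElements; ++i) B[i] = float(i) + 0.5f;

  // Reinterpret the full fragments as two half-length arrays, mirroring the
  // ptr_B / ptr_dst_B reinterpretation in the transform above.
  auto const *ptr_B = reinterpret_cast<std::array<float, kElements / 2> const *>(&B);
  auto *ptr_dst_B = reinterpret_cast<std::array<int, kElements / 2> *>(&dst_B);

  FloatToIntConverter convert_B;
  ptr_dst_B[0] = convert_B(ptr_B[0]);   // convert the first half
  ptr_dst_B[1] = convert_B(ptr_B[1]);   // convert the second half

  for (int v : dst_B) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}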

View File

@ -291,7 +291,9 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
!(Shape::kColumn % InstructionShape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
static_assert(
!(AccumulatorShape::kRow % Shape::kRow) &&
AccumulatorShape::kRow == Shape::kRow,
"Rows of Warp Accumulator must be the same as rows of warp");
static_assert(
!(AccumulatorShape::kColumn % Shape::kColumn),
"Shape of Warp Accumulator must be divisible by warp shape.");
static_assert(
@ -304,7 +306,16 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
private:
static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
static int const kRowsPerIteration = 8;
static int const kColumnsPerIteration = 16;
static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kN / kThreads;
static int const kElementsPerAccess = kRowsPerIteration * kColumnsPerIteration / kThreads;
static int const kIterationsPerAccess = kElementsPerAccess / kElementsPerIteration;
// Number of iterations per actual instruction
static int const kIterationsPerInstruction = InstructionShape::kM / kRowsPerIteration;
static int const kAccessStride = kIterationsPerInstruction;
/// Number of mma operations performed by a warp
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
@ -313,13 +324,15 @@ private:
using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
AccumulatorShape::kColumn / InstructionShape::kN>;
/// Number of Accesses in a warp
using AccessIterations = MatrixShape<MmaIterations::kRow * kIterationsPerInstruction,
MmaIterations::kColumn / kIterationsPerAccess>;
/// Number of K iterations
static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
* (AccumulatorShape::kRow / Shape::kRow);
static int const kResidualIndex = kResidualColumn / Shape::kColumn
* (AccumulatorShape::kRow / Shape::kRow);
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn;
static int const kResidualIndex = kResidualColumn / Shape::kColumn;
public:
@ -338,8 +351,8 @@ public:
private:
/// Internal access type
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentAccessType = Array<Element, kElementsPerAccess>;
using AccessType = Array<ElementAccumulator, kElementsPerIteration>;
using FragmentAccessType = Array<Element, kElementsPerIteration>;
private:
//
@ -386,6 +399,11 @@ public:
return *this;
}
CUTLASS_HOST_DEVICE
void set_index(int idx) {
index_ = idx;
}
/// Loads a fragment from the referenced part of the accumulator tile
CUTLASS_HOST_DEVICE
void load(Fragment &frag, OutputOp output_op) const {
@ -399,21 +417,35 @@ public:
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
// NumericArrayConverter<Element, ElementAccumulator, kElementsPerAccess, FloatRoundStyle::round_indeterminate> fragmentConverter;
int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
* MmaIterations::kColumn;
int index = index_ * AccessIterations::kCount;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; m++) {
for (int n = 0; n < MmaIterations::kColumn; n++) {
int accumulator_access_offset =
(m + index_m) * AccumulatorIterations::kColumn + n + index_n;
for (int i = 0; i < AccessIterations::kCount; i++) {
// int index_m = (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction)
// * kIterationsPerInstruction + index % kIterationsPerInstruction;
//
// int index_n = (index / AccessIterations::kCount) * MmaIterations::kColumn +
// (index % (AccessIterations::kColumn * kIterationsPerInstruction))
// / kIterationsPerInstruction * AccessIterations::kColumn;
//
// int accumulator_access_offset = index_m / kIterationsPerInstruction * AccessIterations::kCount * kIterationsPerInstruction
// + index_m % kIterationsPerInstruction + index_n * kIterationsPerInstruction;
frag_ptr[m * MmaIterations::kColumn + n].clear();
int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) +
(index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) *
AccumulatorIterations::kColumn * kIterationsPerInstruction +
(index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction *
(kIterationsPerInstruction * kIterationsPerAccess) +
(index % kIterationsPerInstruction);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kIterationsPerAccess; j++) {
frag_ptr[i*kIterationsPerAccess + j].clear();
if(!(is_residual_tile_ && index_ >= kResidualIndex))
// frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]);
frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset], src_fragment);
// frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]);
frag_ptr[i*kIterationsPerAccess + j] = output_op(accumulators_[accumulator_access_offset + j * kAccessStride], src_fragment);
}
index++;
}
}
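The accumulator_access_offset expression above is easier to follow with concrete numbers. The sketch below evaluates the same expression for hypothetical sizes (InstructionShape 16x8, warp Shape 32x32, 64 accumulator columns, 32 threads, chosen only for illustration) and prints which accumulator chunk each access touches:

#include <cstdio>

// All concrete sizes are hypothetical; the offset expression is the one used
// by load() above, evaluated for index_ = 0.
int main() {
  int const kThreads = 32;
  int const kInstM = 16, kInstN = 8;
  int const kShapeRow = 32, kShapeColumn = 32;
  int const kAccumColumn = 64;

  int const kRowsPerIteration = 8;
  int const kColumnsPerIteration = 16;
  int const kElementsPerIteration = kRowsPerIteration * kInstN / kThreads;             // 2
  int const kElementsPerAccess = kRowsPerIteration * kColumnsPerIteration / kThreads;  // 4
  int const kIterationsPerAccess = kElementsPerAccess / kElementsPerIteration;         // 2
  int const kIterationsPerInstruction = kInstM / kRowsPerIteration;                    // 2
  int const kAccessStride = kIterationsPerInstruction;

  int const MmaIterationsRow = kShapeRow / kInstM;                                     // 2
  int const MmaIterationsColumn = kShapeColumn / kInstN;                               // 4
  int const AccumIterationsColumn = kAccumColumn / kInstN;                             // 8
  int const AccessIterationsRow = MmaIterationsRow * kIterationsPerInstruction;        // 4
  int const AccessIterationsColumn = MmaIterationsColumn / kIterationsPerAccess;       // 2
  int const AccessIterationsCount = AccessIterationsRow * AccessIterationsColumn;      // 8

  int index = 0;  // index_ * AccessIterations::kCount with index_ = 0
  for (int i = 0; i < AccessIterationsCount; ++i) {
    int accumulator_access_offset =
        index / AccessIterationsCount * (MmaIterationsColumn * kIterationsPerInstruction) +
        (index % AccessIterationsCount) / (AccessIterationsColumn * kIterationsPerInstruction) *
            AccumIterationsColumn * kIterationsPerInstruction +
        (index % (AccessIterationsColumn * kIterationsPerInstruction)) / kIterationsPerInstruction *
            (kIterationsPerInstruction * kIterationsPerAccess) +
        (index % kIterationsPerInstruction);
    for (int j = 0; j < kIterationsPerAccess; ++j) {
      std::printf("access %d.%d reads accumulator chunk %d\n",
                  i, j, accumulator_access_offset + j * kAccessStride);
    }
    ++index;
  }
  return 0;
}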

View File

@ -106,8 +106,11 @@ public:
/// Architecture tag
using ArchTag = arch::Sm70;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Underlying instruction shape
using InstructionShape = typename Policy::Operator::Shape;
using InstructionShape = typename ArchMmaOperator::Shape;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
@ -133,8 +136,8 @@ public:
ElementA,
LayoutA,
MatrixShape<
Policy::Operator::Shape::kM,
Policy::Operator::Shape::kK
ArchMmaOperator::Shape::kM,
ArchMmaOperator::Shape::kK
>,
Policy::OpDelta::kRow,
kThreadCount
@ -150,8 +153,8 @@ public:
ElementB,
LayoutB,
MatrixShape<
Policy::Operator::Shape::kK,
Policy::Operator::Shape::kN
ArchMmaOperator::Shape::kK,
ArchMmaOperator::Shape::kN
>,
Policy::OpDelta::kRow,
kThreadCount
@ -165,7 +168,7 @@ public:
MatrixShape<Shape::kM, Shape::kN>,
ElementC,
LayoutC,
typename Policy::Operator::Shape,
typename ArchMmaOperator::Shape,
typename Policy::OpDelta
>;
@ -175,14 +178,14 @@ public:
private:
static_assert(
!(Shape::kM % Policy::Operator::Shape::kM) &&
!(Shape::kN % Policy::Operator::Shape::kN),
!(Shape::kM % ArchMmaOperator::Shape::kM) &&
!(Shape::kN % ArchMmaOperator::Shape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
/// Number of mma operations performed
using MmaIterations = MatrixShape<
InterleavedTileShape::kM / Policy::Operator::Shape::kM,
InterleavedTileShape::kN / Policy::Operator::Shape::kN
InterleavedTileShape::kM / ArchMmaOperator::Shape::kM,
InterleavedTileShape::kN / ArchMmaOperator::Shape::kN
>;
using TileIterations = MatrixShape<
Shape::kM / InterleavedTileShape::kM,
@ -195,7 +198,7 @@ private:
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
typename Policy::Operator mma;
ArchMmaOperator mma;
public:
@ -215,9 +218,9 @@ public:
FragmentB const &B,
FragmentC const &C) {
using MmaOperandA = typename Policy::Operator::FragmentA;
using MmaOperandB = typename Policy::Operator::FragmentB;
using MmaOperandC = typename Policy::Operator::FragmentC;
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
D = C;

View File

@ -241,6 +241,9 @@ public:
int access_strided_idx = -1;
if (Policy::LdsmShape::kContiguous == 4) {
// Matrix multiply 1688 A/B
// Q0 Q1 Q2 Q3 (each Q is one 8x128-bit block).

// Four blocks are next to each other in the contiguous dimension.
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i);
access_contiguous_idx = (quad_pair ^ lane_in_quad);
access_strided_idx = lane_in_quad_pair;
@ -262,7 +265,17 @@ public:
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1));
access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad);
access_strided_idx = lane_in_quad_quad;
} else if (Policy::LdsmShape::kContiguous == 1) {
// Matrix multiply 16832.SP B
// Q0
// Q1
// Q2
// Q3
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2));
access_contiguous_idx = ((i & 3) ^ lane_in_quad);
access_strided_idx = lane_id;
}
int access_contiguous =
partition_contiguous_idx * Layout::PartitionShape::kContiguous +
access_contiguous_idx;
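The new LdsmShape::kContiguous == 1 branch above stacks Q0..Q3 in the strided dimension for the 16832.SP B operand. A standalone sketch of that lane decomposition, assuming lane_in_quad = lane_id % 4 and lane_in_quad_pair = lane_id % 8 (matching their names, defined earlier in that file); eight pointer slots are iterated purely for illustration:

#include <cstdio>

int main() {
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    int lane_in_quad = (lane_id & 3);        // position within a quad of 4 lanes (assumed)
    int lane_in_quad_pair = (lane_id & 7);   // position within a pair of quads (assumed)

    for (int i = 0; i < 8; ++i) {            // pointer slot being initialized (illustrative count)
      int partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2));
      int access_contiguous_idx = ((i & 3) ^ lane_in_quad);
      int access_strided_idx = lane_id;      // every lane reads its own strided row
      std::printf("lane %2d slot %d -> partition %d, contiguous %d, strided %d\n",
                  lane_id, i, partition_contiguous_idx, access_contiguous_idx,
                  access_strided_idx);
    }
  }
  return 0;
}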
@ -531,24 +544,24 @@ class MmaTensorOpMultiplicandTileIterator<
!(Shape::kContiguous % InstructionShape::kContiguous),
"Shape of warp-level Mma must be divisible by operator shape.");
// Determine number of elements along outer dimension per individual LDS.32
// op. Every one warp of LDS.32 loads 8x4 elements
// Determine the number of elements along the outer dimension per individual
// 32-bit shared memory load op. One warp of 32-bit shared memory loads covers
// 8x4 elements
static int const kLdsOpInner = Layout::TileShape::kStrided;
static int const kLdsOpOuter = kThreads / kLdsOpInner;
static_assert(!(Shape::kContiguous % kLdsOpOuter),
"Shape of warp-level mma must be divisible by LDS.32's "
"Shape of warp-level mma must be divisible by 32bit "
"fundamental tile size.");
static_assert(!(Shape::kStrided % kLdsOpInner),
"Shape of warp-level mma must be divisible by LDS.32's "
"Shape of warp-level mma must be divisible by 32bit "
"fundamental tile size.");
/// Number of LDS.32 instructions needed by one MMA instruction
/// 1684 A 2x1
/// 1684 B 1x1
/// 1688 A 2x2
/// 1688 B 1x2
/// Number of 32 bit shared memory load instructions needed by one MMA instruction
/// 1688 A 2x2
/// 1688 B 1x2
/// 16816 B 1x4
static int const LdsShapeContiguous =
InstructionShape::kContiguous / kLdsOpOuter;
static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner;
@ -639,6 +652,8 @@ class MmaTensorOpMultiplicandTileIterator<
if (Shape::kContiguous ==
Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) {
if (tile_offset.contiguous() % 2) {
// Matrix multiply 1688 pointer_[0] <=> pointer_[4] pointer_[1] <=> pointer_[5]
// pointer_[2] <=> pointer_[6] pointer_[3] <=> pointer_[7]
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kPointerCount / 2; ++i) {
AccessType const *tmp_pointer = pointer_[i];
@ -1535,6 +1550,14 @@ class MmaTensorOpMultiplicandTileIterator<
access_strided_idx =
(lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor;
}
else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) {
// Matrix multiply 16832.SP B
// Q0 Q1 Q2 Q3
partition_contiguous_idx = (lane_id % Layout::kFactor);
access_contiguous_idx =
(quad_pair ^ (lane_in_quad_pair / Layout::kFactor));
access_strided_idx = lane_in_quad_pair / Layout::kFactor;
}
} else if (Layout::kFactor == 1) {
// Super Matrix multiply kBlock = 64
if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
@ -1565,6 +1588,13 @@ class MmaTensorOpMultiplicandTileIterator<
access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad);
access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3);
}
else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) {
// Matrix multiply 16832.SP B
// Q0 Q1 Q2 Q3
partition_contiguous_idx = (lane_in_quad_pair >> 2);
access_contiguous_idx = (quad_pair ^ lane_in_quad);
access_strided_idx = lane_in_quad_pair;
}
}
int access_contiguous =
@ -2369,17 +2399,18 @@ class MmaTensorOpAccumulatorTileIterator<
/// Internal structure of iterator - made public to enable introspection
struct Policy {
static_assert(
static bool const kDivisible =
!(Shape::kRow % InstructionShape::kM) &&
!(Shape::kColumn % InstructionShape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
!(Shape::kColumn % InstructionShape::kN);
static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
"Layouts must be defined for logical MatrixCoord coordinate space.");
/// Number of mma operations performed
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
Shape::kColumn / InstructionShape::kN>;
using MmaIterations = MatrixShape<
(Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM,
(Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN
>;
};
private:
@ -2398,7 +2429,9 @@ public:
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = Array<
Element,
Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>;
private:
@ -2667,17 +2700,18 @@ class MmaTensorOpAccumulatorTileIterator<Shape_, Element_,
/// Internal structure of iterator - made public to enable introspection
struct Policy {
static_assert(
static bool const kDivisible =
!(Shape::kRow % InstructionShape::kM) &&
!(Shape::kColumn % InstructionShape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
!(Shape::kColumn % InstructionShape::kN);
static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
"Layouts must be defined for logical MatrixCoord coordinate space.");
/// Number of mma operations performed
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
Shape::kColumn / InstructionShape::kN>;
using MmaIterations = MatrixShape<
(Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM,
(Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN
>;
};
private:
@ -2696,7 +2730,8 @@ public:
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = Array<Element,
Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>;
private:
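Both accumulator tile iterator specializations above now round MmaIterations up and size Fragment from whole instruction tiles rather than from Shape::kCount. A small standalone sketch of the effect, using hypothetical sizes where the warp tile is smaller than one instruction in M:

#include <cstdio>

// Round-up division as used by the updated MmaIterations above:
// (Shape + InstructionShape - 1) / InstructionShape. Sizes are hypothetical.
int main() {
  int const kInstM = 16, kInstN = 8, kThreads = 32;
  int const kShapeRow = 8, kShapeColumn = 8;   // smaller than one instruction in M

  int mma_rows = (kShapeRow + kInstM - 1) / kInstM;       // 1 (exact division would give 0)
  int mma_cols = (kShapeColumn + kInstN - 1) / kInstN;    // 1
  int fragment_elements = mma_rows * mma_cols * (kInstM * kInstN) / kThreads;  // 4

  std::printf("MmaIterations = %d x %d, Fragment holds %d accumulators per thread\n",
              mma_rows, mma_cols, fragment_elements);
  // With the old definition the fragment would have held Shape::kCount / kThreads
  // = 8*8/32 = 2 elements, which cannot cover a whole m16n8 accumulator tile.
  return 0;
}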

View File

@ -1570,6 +1570,839 @@ public:
}
};
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator specialized for canonical matrix layouts
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand_,
/// Data type of A elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Shape of one matrix production operation (concept: MatrixShape)
typename InstructionShape_,
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
int OpDelta_,
/// Number of threads participating in one matrix operation
int Threads = 32,
/// Number of partitions along K dimension
int PartitionsK_ = 1>
class MmaTensorOpMultiplicandTileIteratorCanonical {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Operand tag
static Operand const kOperand = Operand_;
/// Basic check
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = Layout_;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = InstructionShape_;
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
static int const kOpDelta = OpDelta_;
/// Number of participating threads
static int const kThreads = 32;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Number of elements accessed per Shared Memory load
static int const kElementsPerAccess =
(sizeof_bits<Element>::value >= 32 ? 1 : 32 / sizeof_bits<Element>::value);
private:
static int const kWarpShapeOuter =
(kOperand == Operand::kA ? Shape::kRow : Shape::kColumn);
static int const kWarpShapeInner =
(kOperand == Operand::kA ? Shape::kColumn : Shape::kRow);
/// Rounded up instruction counts
using InstructionCount = MatrixShape<
Shape::kRow / InstructionShape::kRow,
Shape::kColumn / InstructionShape::kColumn
>;
/// Rounded up tile dimensions
using WarpShapeDivisible = MatrixShape<
InstructionCount::kRow * InstructionShape::kRow,
InstructionCount::kColumn * InstructionShape::kColumn
>;
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<
Element,
WarpShapeDivisible::kRow * WarpShapeDivisible::kColumn / kThreads
>;
/// Memory access type
using AccessType = AlignedArray<Element, kElementsPerAccess>;
private:
/// Underlying tensor reference
TensorRef ref_;
/// Extent of tensor
MatrixCoord extent_;
/// Origin
MatrixCoord origin_;
/// Used to conditionally enable extents checking
bool divisible_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical(): divisible_(true) { }
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical(
TensorRef const &ref,
int lane_id
): ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) {
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
}
else {
origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
}
ref_.add_coord_offset(origin_);
}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical(
TensorRef const &ref,
TensorCoord extent,
int lane_id
): ref_(ref), extent_(extent), divisible_(false) {
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
}
else {
origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
}
ref_.add_coord_offset(origin_);
}
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical &add_pointer_offset(LongIndex offset) {
ref_.add_pointer_offset(offset);
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical &add_tile_offset(TensorCoord const &tile_offset) {
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
origin_ += coord_offset;
ref_.add_coord_offset(coord_offset);
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical & operator++() {
if (kOperand == Operand::kA) {
add_tile_offset({0, 1});
}
else {
add_tile_offset({1, 0});
}
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical & operator--() {
if (kOperand == Operand::kA) {
add_tile_offset({0, -1});
}
else {
add_tile_offset({-1, 0});
}
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical & operator+=(TensorCoord const &tile_offset) {
add_tile_offset(tile_offset);
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIteratorCanonical & operator-=(TensorCoord const &tile_offset) {
add_tile_offset(-tile_offset);
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_pointer_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index pointer_offset) const {
int const kWarpShapeDivisibleInner =
(kOperand == Operand::kA ? WarpShapeDivisible::kColumn : WarpShapeDivisible::kRow);
// Take advantage of Tensor Op's 8 x 4T access pattern
int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
AccessType *access_ptr = reinterpret_cast<AccessType *>(&frag);
if (kOperand == Operand::kA) {
int const kTilesPerInstruction = InstructionShape::kRow / 8;
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) {
int access_idx =
access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
MatrixCoord access_coord = origin_ + offset;
if (divisible_ ||
(access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) {
access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
ref_.data() + ref_.offset(offset));
}
else {
AccessType zero;
zero.clear();
access_ptr[access_idx] = zero;
}
}
}
}
}
else {
CUTLASS_PRAGMA_UNROLL
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
MatrixCoord offset(
inner_idx * 4 * kElementsPerAccess,
inst_n_idx * 8);
MatrixCoord access_coord = origin_ + offset;
if (divisible_ ||
(access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) {
access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
ref_.data() + ref_.offset(offset));
}
else {
AccessType zero;
zero.clear();
access_ptr[access_idx] = zero;
}
}
}
}
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index byte_offset) const {
load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits<Element>::value);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset) const {
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
load_with_pointer_offset(frag, ref_.offset(coord_offset));
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index pointer_offset) const {
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index byte_offset) const {
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits<Element>::value);
}
/// Notify the iterator which k-group it is currently pointing to.
///
/// This does not advance the iterator. Rather, it overrides its internal
/// tracking with constant-valued k-group index to enable the compiler to
/// fold constants and achieve more efficient code.
///
/// This is used by some nontrivial permuted layouts.
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
// no operation
}
};
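The canonical iterator's load_with_pointer_offset above predicates each access on the tensor extent when the tile is not divisible, zero-filling anything out of bounds. A standalone sketch of that predication pattern, with hypothetical shapes and simplified 2-D indexing:

#include <cstdio>
#include <vector>

int main() {
  int const kRows = 6, kCols = 5;                 // logical extent of the operand (hypothetical)
  int const kTileRows = 8, kTileCols = 8;         // tile the iterator covers (hypothetical)
  std::vector<float> tensor(kRows * kCols);
  for (int i = 0; i < kRows * kCols; ++i) tensor[i] = float(i);

  std::vector<float> fragment(kTileRows * kTileCols, 0.0f);
  bool const divisible = false;                   // an explicit extent was supplied

  for (int r = 0; r < kTileRows; ++r) {
    for (int c = 0; c < kTileCols; ++c) {
      if (divisible || (r < kRows && c < kCols)) {
        fragment[r * kTileCols + c] = tensor[r * kCols + c];   // in-bounds load
      } else {
        fragment[r * kTileCols + c] = 0.0f;                    // zero-fill out-of-bounds access
      }
    }
  }
  std::printf("fragment[0] = %f, fragment[last] = %f\n",
              fragment.front(), fragment.back());
  return 0;
}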
/// Wrapper for ColumnMajor
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Identifies A or B multiplicand
Operand Operand_,
/// Data type of elements
typename Element_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_,
/// Interval between adjacent *MMA instructions (in units of MMA
/// instructions)
int OpDelta_,
/// Number of partitions along K dimension
int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
Shape_, Operand_, Element_,
cutlass::layout::ColumnMajor,
InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
/// Shape of tile to load (concept: PitchLinearShape)
using Shape = Shape_;
/// Operand tag
static Operand const kOperand = Operand_;
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = cutlass::layout::ColumnMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = InstructionShape_;
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
static int const kOpDelta = OpDelta_;
/// Number of participating threads
static int const kThreads = 32;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Underlying tile iterator implementation
using Base = MmaTensorOpMultiplicandTileIteratorCanonical<
Shape, kOperand, Element,
layout::ColumnMajor,
InstructionShape,
kOpDelta, kThreads, PartitionsK_>;
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment = typename Base::Fragment;
private:
/// Underlying tile iterator
Base iterator_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator() { }
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator(
TensorRef const &ref,
int lane_id
): iterator_({ref.data(), ref.stride()}, lane_id) {
}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator(
TensorRef const &ref,
TensorCoord const & extent,
int lane_id
): iterator_({ref.data(), ref.stride()}, extent, lane_id) {
}
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
iterator_.add_pointer_offset(offset);
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator & operator++() {
++iterator_;
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator & operator--() {
--iterator_;
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column()));
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column()));
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
iterator_.load(frag);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_pointer_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index pointer_offset) const {
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index byte_offset) const {
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset) const {
// TODO
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index pointer_offset) const {
// TODO
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index byte_offset) const {
iterator_.load_with_byte_offset(
frag,
{tile_offset.contiguous(), tile_offset.strided()},
byte_offset);
}
/// Notify the iterator which k-group it is currently pointing to.
///
/// This does not advance the iterator. Rather, it overrides its internal
/// tracking with constant-valued k-group index to enable the compiler to
/// fold constants and achieve more efficient code.
///
/// This is used by some nontrivial permuted layouts.
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
iterator_.set_kgroup_index(k_group);
}
};
/// Wrapper for RowMajor
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Identifies A or B multiplicand
Operand Operand_,
/// Data type of elements
typename Element_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_,
/// Interval between adjacent *MMA instructions (in units of MMA
/// instructions)
int OpDelta_,
/// Number of partitions along K dimension
int PartitionsK_>
class MmaTensorOpMultiplicandTileIterator<
Shape_, Operand_, Element_,
cutlass::layout::RowMajor,
InstructionShape_, OpDelta_, 32, PartitionsK_> {
public:
/// Shape of tile to load (concept: PitchLinearShape)
using Shape = Shape_;
/// Operand tag
static Operand const kOperand = Operand_;
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = cutlass::layout::RowMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = InstructionShape_;
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
static int const kOpDelta = OpDelta_;
/// Number of participating threads
static int const kThreads = 32;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Underlying tile iterator implementation
using Base = MmaTensorOpMultiplicandTileIteratorCanonical<
Shape, kOperand, Element,
layout::RowMajor,
InstructionShape,
kOpDelta, kThreads, PartitionsK_>;
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment = typename Base::Fragment;
private:
/// Underlying tile iterator
Base iterator_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator() { }
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator(
TensorRef const &ref,
int lane_id
): iterator_({ref.data(), ref.stride()}, lane_id) {
}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator(
TensorRef const &ref,
TensorCoord const &extent,
int lane_id
): iterator_({ref.data(), ref.stride()}, extent, lane_id) {
}
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
iterator_.add_pointer_offset(offset);
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator & operator++() {
++iterator_;
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
MmaTensorOpMultiplicandTileIterator & operator--() {
--iterator_;
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column()));
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column()));
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
iterator_.load(frag);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_pointer_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index pointer_offset) const {
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index byte_offset) const {
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset) const {
// TODO
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index pointer_offset) const {
// TODO
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index byte_offset) const {
iterator_.load_with_byte_offset(
frag,
{tile_offset.contiguous(), tile_offset.strided()},
byte_offset);
}
/// Notify the iterator which k-group it is currently pointing to.
///
/// This does not advance the iterator. Rather, it overrides its internal
/// tracking with constant-valued k-group index to enable the compiler to
/// fold constants and achieve more efficient code.
///
/// This is used by some nontrivial permuted layouts.
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
iterator_.set_kgroup_index(k_group);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp

View File

@ -0,0 +1,374 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators to load sparse metadata used by warp-level matrix multiply operations
targeting Sparse Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
////////////////////////////////////////////////////////////////////////////////
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Data type of A elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape_,
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
int OpDelta_,
/// Number of threads participating in one matrix operation
int Threads,
/// Number of partitions along K dimension
int PartitionsK_ = 1>
class SparseMmaTensorOpMetaTileIterator {
public:
/// Shape of tile to load (concept: PitchLinearShape)
using Shape = Shape_;
/// Element type
using Element = Element_;
/// Layout of source tile
using Layout = Layout_;
/// Shape of one matrix product operation (concept: GemmShape)
using InstructionShape = InstructionShape_;
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
static int const kOpDelta = OpDelta_;
/// Number of participating threads
static int const kThreads = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
static int const kSparse = 2;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Internal structure of iterator - made public to enable introspection
struct Policy {
static_assert(
!(Shape::kColumn % InstructionShape::kColumn),
"Shape of warp-level Mma must be divisible by operator shape.");
static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
// Determine number of elements along outer dimension per individual LDSM op
static int const kLdsmOpOuter = InstructionShape::kColumn;
static int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter;
static_assert(!(Shape::kColumn % kLdsmOpOuter),
"Shape of warp-level mma must be divisible by LDSM's "
"fundamental tile size.");
static_assert(!(Shape::kRow % kLdsmOpInner),
"Shape of warp-level mma must be divisible by LDSM's "
"fundamental tile size.");
/// Shape of one individual LDSM instruction
static int const LdsmShapeColumn =
InstructionShape::kColumn / kLdsmOpOuter;
static int const LdsmShapeRow =
((4 / LdsmShapeColumn * kLdsmOpInner) > Shape::kRow)
? (Shape::kRow / kLdsmOpInner)
: (4 / LdsmShapeColumn);
using LdsmShape =
layout::PitchLinearShape<LdsmShapeRow, LdsmShapeColumn>;
/// Number and arrangement of LDSM instructions
using LdsmIterations = layout::PitchLinearShape<
Shape::kRow / kLdsmOpInner / LdsmShapeRow,
1>;
/// Number of groups for each tile
static int const kGroupsPerTile =
Shape::kColumn / InstructionShape::kColumn;
};
private:
/// Not working on this feature at the moment.
static_assert(kOpDelta == 1,
"Alternative arrangements not supported at present.");
/// Pointer type used for accesses
using AccessType = Array<Element, Policy::kElementsPerAccess>;
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment =
Array<Element, Shape::kRow * InstructionShape::kColumn / kThreads>;
private:
/// Layout object storing stride values
Index stride_;
/// Shared memory base pointers - not advanced
AccessType const *pointer_;
/// Byte offset incremented as iterator advances
Index byte_offset_;
/// Internal counter used to determine when to increment byte offset and when
/// to XOR it
int k_group_idx_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
SparseMmaTensorOpMetaTileIterator()
: pointer_(nullptr),
stride_(0),
byte_offset_(0),
k_group_idx_(0) {}
/// Constructor from TensorRef
CUTLASS_DEVICE
SparseMmaTensorOpMetaTileIterator(TensorRef const &ref, int lane_id)
: pointer_(reinterpret_cast<AccessType const *>(ref.data())),
stride_(ref.stride(0) / Policy::kElementsPerAccess),
byte_offset_(0),
k_group_idx_(0) {
int access_contiguous = (lane_id % (Shape::kRow / Policy::kElementsPerAccess));
int access_strided = (lane_id / (Shape::kRow / Policy::kElementsPerAccess));
byte_offset_ = (access_contiguous + access_strided * stride_) *
sizeof_bits<Element>::value * Policy::kElementsPerAccess / 8;
}
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_DEVICE
SparseMmaTensorOpMetaTileIterator &add_pointer_offset(LongIndex offset) {
byte_offset_ += offset * sizeof_bits<Element>::value / 8;
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_DEVICE
SparseMmaTensorOpMetaTileIterator &add_tile_offset(
TensorCoord const &tile_offset) {
int offset = tile_offset.row() * Shape::kRow +
tile_offset.column() * InstructionShape::kColumn * stride_ *
Policy::kElementsPerAccess;
add_pointer_offset(offset);
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
SparseMmaTensorOpMetaTileIterator &operator++() {
add_tile_offset({0, 1});
if (kPartitionsK > 1) {
++k_group_idx_;
// Jump to next stage
if (k_group_idx_ == Policy::kGroupsPerTile) {
k_group_idx_ = 0;
add_tile_offset(
{0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)});
}
}
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
SparseMmaTensorOpMetaTileIterator &operator--(){
byte_offset_ -= stride_ * InstructionShape::kColumn *
sizeof_bits<Element>::value * Policy::kElementsPerAccess /
8;
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of
///< the tensor
CUTLASS_DEVICE SparseMmaTensorOpMetaTileIterator &
operator+=(TensorCoord const &tile_offset) {
add_tile_offset(tile_offset);
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of
///< the tensor
CUTLASS_DEVICE
SparseMmaTensorOpMetaTileIterator &operator-=(
TensorCoord const &tile_offset) {
add_tile_offset(-tile_offset);
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const { load_with_byte_offset(frag, 0); }
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset in units of bytes
Index byte_offset) const {
Array<unsigned, Policy::LdsmShape::kCount> *fetch_ptr =
reinterpret_cast<Array<unsigned, Policy::LdsmShape::kCount> *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) {
int access_idx = c + s * Policy::LdsmIterations::kContiguous;
AccessType const *source_ptr =
pointer_ +
Policy::LdsmShape::kContiguous * Policy::kLdsmOpInner * c +
Policy::LdsmShape::kStrided * s * stride_;
char const *source_byte_ptr = reinterpret_cast<char const *>(source_ptr) +
byte_offset + byte_offset_;
cutlass::arch::ldsm<layout::RowMajor, Policy::LdsmShape::kCount>(
fetch_ptr[access_idx], source_byte_ptr);
}
}
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_pointer_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index pointer_offset) const {
load_with_byte_offset(frag, pointer_offset * sizeof(Element));
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset) const {
load_with_byte_offset(frag, tile_offset, 0);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index pointer_offset) const {
load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element));
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index byte_offset) const {
Index pointer_offset =
tile_offset.contiguous() * Shape::kRow / Layout::kElementsPerAccess +
tile_offset.strided() * InstructionShape::kColumn * stride_;
byte_offset += sizeof(AccessType) * pointer_offset;
load_with_byte_offset(frag, byte_offset);
}
/// Notify the iterator which k-group it is currently pointing to.
///
/// This does not advance the iterator. Rather, it overrides its internal
/// tracking with constant-valued k-group index to enable the compiler to
/// fold constants and achieve more efficient code.
///
/// This is used by some nontrivial permuted layouts.
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
// no op
}
};
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
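The Policy arithmetic in SparseMmaTensorOpMetaTileIterator determines the LDSM shape and the number of k-groups per tile. The sketch below evaluates those expressions for a hypothetical 16-bit metadata element and hypothetical Shape = 16x32, InstructionShape::kColumn = 8, numbers chosen only so they satisfy the static_asserts above:

#include <cstdio>

int main() {
  int const kElementBits = 16;                  // hypothetical metadata element width
  int const kShapeRow = 16, kShapeColumn = 32;  // hypothetical warp meta tile
  int const kInstColumn = 8;                    // hypothetical InstructionShape::kColumn

  int const kElementsPerAccess = 128 / kElementBits;              // 8
  int const kLdsmOpOuter = kInstColumn;                           // 8
  int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter; // 8

  int const LdsmShapeColumn = kInstColumn / kLdsmOpOuter;         // 1
  int const LdsmShapeRow =
      ((4 / LdsmShapeColumn * kLdsmOpInner) > kShapeRow)
          ? (kShapeRow / kLdsmOpInner)                            // 2
          : (4 / LdsmShapeColumn);
  int const LdsmIterationsRow = kShapeRow / kLdsmOpInner / LdsmShapeRow;  // 1
  int const kGroupsPerTile = kShapeColumn / kInstColumn;          // 4

  std::printf("LdsmShape = <%d, %d>, LdsmIterations = <%d, 1>, groups per tile = %d\n",
              LdsmShapeRow, LdsmShapeColumn, LdsmIterationsRow, kGroupsPerTile);
  return 0;
}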

View File

@ -349,7 +349,7 @@ struct alignas(2) half_t {
/// Default constructor
CUTLASS_HOST_DEVICE
half_t() { }
half_t() : storage(0) { }
/// Reinterpret cast from CUDA's half type
CUTLASS_HOST_DEVICE
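The change above gives half_t a zero-initializing default constructor instead of leaving its storage indeterminate. A tiny standalone stand-in (a hypothetical demo type, not cutlass::half_t) showing the resulting behavior:

#include <cstdint>
#include <cstdio>

// Hypothetical 16-bit storage word whose default constructor zero-initializes,
// mirroring the change to half_t above.
struct alignas(2) demo_half {
  std::uint16_t storage;
  demo_half() : storage(0) { }   // default construction now yields the bits of +0.0
};

int main() {
  demo_half h;
  std::printf("default bits = 0x%04x\n", h.storage);
  return 0;
}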

View File

@ -83,11 +83,10 @@ struct integer_subbyte {
integer_subbyte(unsigned value)
: storage(reinterpret_cast<Storage const &>(value) & kMask) {}
/// Conversion from double
CUTLASS_HOST_DEVICE
integer_subbyte(double value) {
T tmp = (T)value;
storage = reinterpret_cast<Storage const &>(tmp) & kMask;
T tmp = static_cast<T>(value);
storage = Storage(reinterpret_cast<unsigned const &>(tmp) & kMask);
}
///
@ -155,6 +154,12 @@ struct integer_subbyte {
/// 1-bit Unsigned integer type
using uint1b_t = integer_subbyte<1, false>;
/// 2-bit Integer type
using int2b_t = integer_subbyte<2, true>;
/// 2-bit Unsigned integer type
using uint2b_t = integer_subbyte<2, false>;
/// 4-bit Integer type
using int4b_t = integer_subbyte<4, true>;
@ -169,6 +174,18 @@ struct sizeof_bits<uint1b_t> {
static int const value = 1;
};
/// Defines the size of an element in bits - specialized for int2b_t
template <>
struct sizeof_bits<int2b_t> {
static int const value = 2;
};
/// Defines the size of an element in bits - specialized for uint2b_t
template <>
struct sizeof_bits<uint2b_t> {
static int const value = 2;
};
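With the new 2-bit types, integer_subbyte keeps only the low Bits bits of the converted value in its storage word, so int2b_t covers [-2, 1] and uint2b_t covers [0, 3]. A standalone stand-in (a hypothetical demo class, not the cutlass template; the read-back path is this demo's own) illustrating the mask-and-sign-extend behavior:

#include <cstdint>
#include <cstdio>
#include <type_traits>

template <int Bits, bool Signed>
struct demo_subbyte {
  using T = typename std::conditional<Signed, std::int8_t, std::uint8_t>::type;
  std::uint8_t storage;

  demo_subbyte(double value) {
    T tmp = static_cast<T>(value);
    storage = std::uint8_t(tmp) & std::uint8_t((1u << Bits) - 1);  // keep the low Bits bits
  }

  // Sign-extend (or not) when reading the value back out of storage.
  int get() const {
    if (Signed && (storage & (1u << (Bits - 1)))) {
      return int(storage) - (1 << Bits);
    }
    return int(storage);
  }
};

int main() {
  demo_subbyte<2, true>  a(-2.0);   // int2b_t range is [-2, 1]
  demo_subbyte<2, true>  b(3.0);    // 3 wraps to -1 after masking
  demo_subbyte<2, false> c(3.0);    // uint2b_t range is [0, 3]
  std::printf("a=%d b=%d c=%d\n", a.get(), b.get(), c.get());
  return 0;
}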
/// Defines the size of an element in bits - specialized for int4b_t
template <>
struct sizeof_bits<int4b_t> {

View File

@ -35,7 +35,6 @@
#include "cutlass/cutlass.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/matrix_traits.h"
namespace cutlass {
namespace layout {
@ -803,7 +802,7 @@ private:
// Data members
//
MatrixLayout layout_id_;
Matrix layout_id_;
/// Stride data member
Stride stride_;
@ -815,12 +814,12 @@ public:
/// Ctor
CUTLASS_HOST_DEVICE
GeneralMatrix(): layout_id_(MatrixLayout::kColumnMajor), stride_(make_Coord(0, 1)) { }
GeneralMatrix(): layout_id_(Matrix::kColumnMajor), stride_(make_Coord(0, 1)) { }
/// Ctor
CUTLASS_HOST_DEVICE
GeneralMatrix(
MatrixLayout layout_id,
Matrix layout_id,
Index ldm,
Index interleave): layout_id_(layout_id), stride_(make_Coord(ldm, interleave)) { }
@ -828,11 +827,11 @@ public:
CUTLASS_HOST_DEVICE
static GeneralMatrix packed(
MatrixCoord const &extent,
MatrixLayout layout_id = MatrixLayout::kColumnMajor,
Matrix layout_id = Matrix::kColumnMajor,
Index interleave = 1) {
Index c;
if (layout_id == MatrixLayout::kRowMajor) {
if (layout_id == Matrix::kRowMajor) {
c = extent.column();
}
else {
@ -849,7 +848,7 @@ public:
CUTLASS_HOST_DEVICE
LongIndex operator()(MatrixCoord const &coord) const {
Index c, s;
if (layout_id_ == MatrixLayout::kRowMajor) {
if (layout_id_ == Matrix::kRowMajor) {
c = coord.column();
s = coord.row();
}
@ -871,7 +870,7 @@ public:
}
CUTLASS_HOST_DEVICE
MatrixLayout layout_id() const {
Matrix layout_id() const {
return layout_id_;
}
@ -882,7 +881,7 @@ public:
}
CUTLASS_HOST_DEVICE
MatrixLayout & layout_id() {
Matrix & layout_id() {
return layout_id_;
}
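GeneralMatrix now keys its behavior off the renamed Matrix enumeration; operator() swaps which logical coordinate is treated as contiguous based on layout_id_. A standalone sketch of just that coordinate swap (the rest of the offset computation is not reproduced here, and the column-major branch is inferred by symmetry with the row-major one shown above):

#include <cstdio>

enum class Matrix { kColumnMajor, kRowMajor };   // stand-in for the Matrix enumeration above

void describe(Matrix layout_id, int row, int column) {
  int c, s;
  if (layout_id == Matrix::kRowMajor) {
    c = column;  s = row;      // row-major: columns are contiguous
  } else {
    c = row;     s = column;   // column-major: rows are contiguous
  }
  std::printf("contiguous index = %d, strided index = %d\n", c, s);
}

int main() {
  describe(Matrix::kRowMajor, 2, 5);
  describe(Matrix::kColumnMajor, 2, 5);
  return 0;
}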
@ -902,7 +901,7 @@ public:
CUTLASS_HOST_DEVICE
LongIndex capacity(MatrixCoord const &extent) const {
Index s;
if (layout_id_ == MatrixLayout::kRowMajor) {
if (layout_id_ == Matrix::kRowMajor) {
s = extent.row();
}
else {

View File

@ -79,7 +79,7 @@ private:
// Data members
//
/// Stride data member - [c, wc, hwc]
/// Stride data member - [stride_w, stride_h, stride_n]
Stride stride_;
public:
@ -93,7 +93,12 @@ public:
/// Constructor
CUTLASS_HOST_DEVICE
TensorNHWC(typename Stride::Index c, typename Stride::Index wc, typename Stride::Index hwc): stride_(make_Coord(c, wc, hwc)) { }
TensorNHWC(
typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates
typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates
typename Stride::Index stride_n ///< number of elements between adjacent N coordinates
):
stride_(make_Coord(stride_w, stride_h, stride_n)) { }
/// Helper returns a layout to a tightly packed NHWC tensor.
CUTLASS_HOST_DEVICE
@ -116,12 +121,6 @@ public:
LongIndex(stride_[2] * coord.n());
}
/// Returns a RowMajor equivalent for a TensorNHWC layout
CUTLASS_HOST_DEVICE
explicit operator RowMajor() {
return RowMajor(stride_[0]);
}
/// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory.
CUTLASS_HOST_DEVICE
TensorCoord inverse(LongIndex index) const {
@ -444,6 +443,107 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Mapping function for 5-D NDHWC tensors.
class TensorNDHWC {
public:
/// Logical rank of tensor
static int const kRank = 5;
/// Rank of stride vector
static int const kStrideRank = 4;
/// Index type used for coordinates
using Index = int32_t;
/// Long index type used for offsets
using LongIndex = int64_t;
/// Logical coordinate (n, d, h, w, c)
using TensorCoord = Tensor5DCoord;
/// Stride vector
using Stride = Coord<kStrideRank>;
private:
//
// Data members
//
/// Stride data member - [c, wc, hwc, dhwc]
Stride stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_HOST_DEVICE
TensorNDHWC(Stride const &stride = Stride(0)): stride_(stride) { }
/// Constructor
CUTLASS_HOST_DEVICE
TensorNDHWC(
typename Stride::Index c,
typename Stride::Index wc,
typename Stride::Index hwc,
typename Stride::Index dhwc):
stride_(make_Coord(c, wc, hwc, dhwc)) { }
/// Helper returns a layout to a tightly packed NDHWC tensor.
CUTLASS_HOST_DEVICE
static TensorNDHWC packed(TensorCoord const &extent) {
return TensorNDHWC(
make_Coord(
extent.c(),
extent.w() * extent.c(),
extent.h() * extent.w() * extent.c(),
extent.d() * extent.h() * extent.w() * extent.c()
)
);
}
/// Returns the offset of a coordinate (n, d, h, w, c) in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(TensorCoord const &coord) const {
return coord.c() +
LongIndex(stride_[0] * coord.w()) +
LongIndex(stride_[1] * coord.h()) +
LongIndex(stride_[2] * coord.d()) +
LongIndex(stride_[3] * coord.n());
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride stride() const {
return stride_;
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride & stride() {
return stride_;
}
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(TensorCoord const &extent) const {
// It does not make sense if the extent is larger than the stride,
// and we cannot rely on the capacity calculation in such cases.
// These checks could be moved to debug-only code.
if ((extent.c() > stride_[0])
|| (extent.w() * stride_[0] > stride_[1])
|| (extent.h() * stride_[1] > stride_[2])
|| (extent.d() * stride_[2] > stride_[3])) {
assert(0);
}
return extent.n() * stride_[3];
}
};
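As a quick illustration of the layout above, here is a minimal sketch (not part of the diff; it assumes TensorNDHWC lives in cutlass/layout/tensor.h alongside TensorNHWC, and uses cutlass::Tensor5DCoord as declared above):

#include "cutlass/layout/tensor.h"

// Build a packed layout for an N=2, D=4, H=8, W=8, C=16 tensor and compute
// the linear offset of coordinate (n=1, d=2, h=3, w=4, c=5).
void ndhwc_offset_example() {
  cutlass::Tensor5DCoord extent(2, 4, 8, 8, 16);   // (n, d, h, w, c)
  cutlass::layout::TensorNDHWC layout =
      cutlass::layout::TensorNDHWC::packed(extent); // stride = [16, 128, 1024, 4096]

  // offset = c + 16*w + 128*h + 1024*d + 4096*n = 5 + 64 + 384 + 2048 + 4096 = 6597
  cutlass::layout::TensorNDHWC::LongIndex offset =
      layout(cutlass::Tensor5DCoord(1, 2, 3, 4, 5));
  (void)offset;
}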
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layout

View File

@ -81,17 +81,23 @@ struct TensorOpMultiplicand {
static int const kFactor =
kTileShapeContiguous * kElementsPerAccess / kCrosswise;
/// The strided dimension needs to be at least WarpSize(32) /
/// kTileShapeContiguous for a warp to access. To ensure conflict free
static_assert(
(kFactor > 0),
"kCrosswise should be no large than one shared memory cache line.");
/// The strided dimension needs to be at least (WarpSize(32) /
/// kTileShapeContiguous) for a warp to access. To ensure conflict free
/// access, it also needs to be at least (kTileShapeContiguous / kFactor).
/// See comments below
static int const kTileShapeStride =
((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous))
? (kTileShapeContiguous / kFactor)
: (32 / kTileShapeContiguous);
/// Fundamental tile shape in units of vectors
/// For TN kblock=32 and 8x8x16 shapes, TileShape = <8, 4>.
/// For the rest, TileShape = <8, 8>
/// Fundamental tile shape in units of vectors to guarantee bank conflict free
/// shared memory load/store.
/// For kFactor = 1, TileShape = <8, 8>
/// For kFactor > 1, TileShape = <8, 4>
using TileShape = PitchLinearShape<kTileShapeContiguous, kTileShapeStride>;
/// Fundamental partition shape in units of vectors
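To make the comments above concrete, here is the arithmetic worked out under assumed values (128-bit shared memory accesses, so kTileShapeContiguous = 8, and 16-bit elements, so kElementsPerAccess = 8; both are assumptions for illustration):

// kCrosswise = 64:  kFactor = 8 * 8 / 64 = 1
//                   kTileShapeStride = max(8 / 1, 32 / 8) = 8  ->  TileShape = <8, 8>
// kCrosswise = 32:  kFactor = 8 * 8 / 32 = 2
//                   kTileShapeStride = max(8 / 2, 32 / 8) = 4  ->  TileShape = <8, 4>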

14111
include/cutlass/matrix.h Normal file

File diff suppressed because it is too large

View File

@ -515,17 +515,30 @@ struct NumericConverterClamp<T, float> {
using source_type = float;
static_assert((platform::is_same<result_type, int32_t>::value ||
platform::is_same<result_type, int16_t>::value ||
platform::is_same<result_type, uint16_t>::value ||
platform::is_same<result_type, int8_t>::value ||
platform::is_same<result_type, cutlass::int4b_t>::value),
platform::is_same<result_type, uint8_t>::value ||
platform::is_same<result_type, cutlass::int4b_t>::value ||
platform::is_same<result_type, cutlass::uint4b_t>::value),
"Clamp is only needed for integer types");
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & s) {
NumericConverter<result_type, double> convert_op;
double kClamp_max, kClamp_min;
double kClamp_max = double((1U << (sizeof_bits<result_type>::value - 1)) - 1);
double kClamp_min = -kClamp_max - 1;
if (platform::is_same<result_type, int32_t>::value ||
platform::is_same<result_type, int16_t>::value ||
platform::is_same<result_type, int8_t>::value ||
platform::is_same<result_type, cutlass::int4b_t>::value) {
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value - 1)) - 1);
kClamp_min = -kClamp_max - 1;
} else {
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value)) - 1);
kClamp_min = 0;
}
double source = s;
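For reference, the clamp ranges this logic produces, computed from sizeof_bits as above (illustration only, not part of the diff):

// int8_t   : [-128, 127]               int4b_t  : [-8, 7]
// uint8_t  : [0, 255]                  uint4b_t : [0, 15]
// int16_t  : [-32768, 32767]           int32_t  : [-2147483648, 2147483647]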
@ -946,6 +959,130 @@ struct NumericArrayConverter<int8_t, int, N, Round> {
return convert(s);
}
};
/// Partial specialization for Array<uint8_t, 1> <= Array<int, 1>
template <
FloatRoundStyle Round
>
struct NumericArrayConverter<uint8_t, int, 1, Round> {
using result_type = Array<uint8_t, 1>;
using source_type = Array<int, 1>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
NumericConverter<uint8_t, int, Round> convert_element_;
result_type result;
result[0] = convert_element_(source[0]);
return result;
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
/// Partial specialization for Array<uint8_t, 2> <= Array<int, 2>
template <
FloatRoundStyle Round
>
struct NumericArrayConverter<uint8_t, int, 2, Round> {
using result_type = Array<uint8_t, 2>;
using source_type = Array<int, 2>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
uint32_t tmp;
asm volatile(
"cvt.pack.sat.u8.s32.b32 %0, %2, %1, 0;\n"
: "=r"(tmp) : "r"(source[0]), "r"(source[1]));
uint16_t out = (tmp & 0xffff);
return reinterpret_cast<result_type const &>(out);
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
/// Partial specialization for Array<uint8_t, 4> <= Array<int, 4>
template <
FloatRoundStyle Round
>
struct NumericArrayConverter<uint8_t, int, 4, Round> {
using result_type = Array<uint8_t, 4>;
using source_type = Array<int, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
unsigned out;
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u8.s32.b32 r4, %4, %3, 0;"
"cvt.pack.sat.u8.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]));
return reinterpret_cast<result_type const &>(out);
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
/// Partial specialization for Array<uint8_t> <= Array<int>
template <
int N,
FloatRoundStyle Round
>
struct NumericArrayConverter<uint8_t, int, N, Round> {
static_assert(!(N % 4), "N must be multiple of 4.");
using result_type = Array<uint8_t, N>;
using source_type = Array<int, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
NumericArrayConverter<uint8_t, int, 4, Round> convert_vector_;
result_type result;
Array<uint8_t, 4> *result_ptr = reinterpret_cast<Array<uint8_t, 4> *>(&result);
Array<int, 4> const *source_ptr = reinterpret_cast<Array<int, 4> const *>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
#endif
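A minimal usage sketch of the packed uint8_t converters above (the function name is hypothetical; it assumes cutlass/array.h and cutlass/numeric_conversion.h are included and that the surrounding architecture guard is satisfied so the cvt.pack.sat path is available):

// Saturate four 32-bit accumulators into four uint8_t values in one register.
CUTLASS_DEVICE
cutlass::Array<uint8_t, 4> quantize_to_u8(cutlass::Array<int, 4> const &acc) {
  cutlass::NumericArrayConverter<uint8_t, int, 4> converter;
  return converter(acc);  // each element saturates to [0, 255]
}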
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1025,12 +1162,84 @@ struct NumericArrayConverter<int4b_t, int, N, Round> {
}
};
/// Partial specialization for Array<uint4b_t, 8> <= Array<int, 8>
template <
FloatRoundStyle Round
>
struct NumericArrayConverter<uint4b_t, int, 8, Round> {
using result_type = Array<uint4b_t, 8>;
using source_type = Array<int, 8>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
unsigned out;
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(out)
: "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]),
"r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7]));
return reinterpret_cast<result_type const &>(out);
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
/// Partial specialization for Array<uint4b_t> <= Array<int>
template <
int N,
FloatRoundStyle Round
>
struct NumericArrayConverter<uint4b_t, int, N, Round> {
static_assert(!(N % 8), "N must be multiple of 8.");
using result_type = Array<uint4b_t, N>;
using source_type = Array<int, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {
NumericArrayConverter<uint4b_t, int, 8, Round> convert_vector_;
result_type result;
Array<uint4b_t, 8> *result_ptr = reinterpret_cast<Array<uint4b_t, 8> *>(&result);
Array<int, 8> const *source_ptr = reinterpret_cast<Array<int, 8> const *>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 8; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return convert(s);
}
};
#endif // Conditional guards to enable partial specialization for packed integers
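The wide uint4b_t specialization above vectorizes in groups of eight. For example, a sketch under the same guards (hypothetical function name):

// Pack sixteen 32-bit accumulators into sixteen 4-bit values (two cvt.pack sequences).
CUTLASS_DEVICE
cutlass::Array<cutlass::uint4b_t, 16> quantize_to_u4(cutlass::Array<int, 16> const &acc) {
  cutlass::NumericArrayConverter<cutlass::uint4b_t, int, 16> converter;
  return converter(acc);  // each element saturates to [0, 15]
}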
/////////////////////////////////////////////////////////////////////////////////////////////////
/// FastNumericArrayConverter only works when the source is within center range.
/// Conversion operator for Array
/// Conversion operator for Array. See the comments before
/// FastLinearCombinationClamp.
template <typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
struct FastNumericArrayConverter {

View File

@ -0,0 +1,616 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a densely packed quaternion object intended for storing data in registers and
executing quaternion operations within a CUDA or host thread.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Quaternion: xi + yj + zk + w
template <
typename Element_ = float ///< element type
>
class Quaternion : public Array<Element_, 4> {
public:
/// Logical rank of tensor index space
static int const kRank = 1;
/// Number of elements
static int const kExtent = 4;
/// Base class is a four-element array
using Base = Array<Element_, kExtent>;
/// Element type
using Element = typename Base::Element;
/// Reference type to an element
using Reference = typename Base::reference;
/// Index type
using Index = int;
/// Quaternion storage - imaginary part
static int const kX = 0;
/// Quaternion storage - imaginary part
static int const kY = 1;
/// Quaternion storage - imaginary part
static int const kZ = 2;
/// Quaternion storage - real part
static int const kW = 3;
public:
//
// Methods
//
/// Constructs a quaternion
CUTLASS_HOST_DEVICE
Quaternion(
Element w_ = Element(1)
) {
Base::at(kX) = Element(0);
Base::at(kY) = Element(0);
Base::at(kZ) = Element(0);
Base::at(kW) = w_;
}
/// Constructs a quaternion
CUTLASS_HOST_DEVICE
Quaternion(
Element x_,
Element y_,
Element z_,
Element w_
) {
Base::at(kX) = x_;
Base::at(kY) = y_;
Base::at(kZ) = z_;
Base::at(kW) = w_;
}
/// Constructs a quaternion from a vector representing the imaginary part and a real number
CUTLASS_HOST_DEVICE
Quaternion(
Matrix3x1<Element> const &imag_,
Element w_ = Element()
) {
Base::at(kX) = imag_[0];
Base::at(kY) = imag_[1];
Base::at(kZ) = imag_[2];
Base::at(kW) = w_;
}
/// Returns a reference to the element at a given index
CUTLASS_HOST_DEVICE
Reference at(Index idx) const {
return Base::at(idx);
}
/// Returns a reference to the element at a given index
CUTLASS_HOST_DEVICE
Reference at(Index idx) {
return Base::at(idx);
}
/// Accesses the x element of the imaginary part of the quaternion
CUTLASS_HOST_DEVICE
Element x() const {
return Base::at(kX);
}
/// Accesses the x element of the imaginary part of the quaternion
CUTLASS_HOST_DEVICE
Reference x() {
return Base::at(kX);
}
/// Accesses the y element of the imaginary part of the quaternion
CUTLASS_HOST_DEVICE
Element y() const {
return Base::at(kY);
}
/// Accesses the y element of the imaginary part of the quaternion
CUTLASS_HOST_DEVICE
Reference y() {
return Base::at(kY);
}
/// Accesses the z element of the imaginary part of the quaternion
CUTLASS_HOST_DEVICE
Element z() const {
return Base::at(kZ);
}
/// Accesses the z element of the imaginary part of the quaternion
CUTLASS_HOST_DEVICE
Reference z() {
return Base::at(kZ);
}
/// Accesses the real part of the quaternion
CUTLASS_HOST_DEVICE
Element w() const {
return Base::at(kW);
}
/// Accesses the real part of the quaternion
CUTLASS_HOST_DEVICE
Reference w() {
return Base::at(kW);
}
/// Returns the pure imaginary part of the quaternion as a 3-vector
CUTLASS_HOST_DEVICE
Matrix3x1<Element> pure() const {
return Matrix3x1<Element>(x(), y(), z());
}
/// Returns a quaternion representation of a spatial rotation given a unit-length axis and
/// a rotation in radians.
CUTLASS_HOST_DEVICE
static Quaternion<Element> rotation(
Matrix3x1<Element> const &axis_unit, ///< axis of rotation (assumed to be unit length)
Element theta) { ///< angular rotation in radians
Element s = fast_sin(theta / Element(2));
return Quaternion(
s * axis_unit[0],
s * axis_unit[1],
s * axis_unit[2],
fast_cos(theta / Element(2))
);
}
/// Returns a quaternion representation of a spatial rotation represented as a
/// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians
CUTLASS_HOST_DEVICE
static Quaternion<Element> rotation(
Element r_x,
Element r_y,
Element r_z,
Element theta) { ///< angular rotation in radians
return rotation({r_x, r_y, r_z}, theta);
}
/// Geometric rotation of a 3-element vector
CUTLASS_HOST_DEVICE
Matrix3x1<Element> rotate(Matrix3x1<Element> const &rhs) const {
return (*this * Quaternion<Element>(rhs, 0) * reciprocal(*this)).pure();
}
/// Inverse rotation operation
CUTLASS_HOST_DEVICE
Matrix3x1<Element> rotate_inv(Matrix3x1<Element> const &rhs) const {
return (reciprocal(*this) * Quaternion<Element>(rhs, 0) * *this).pure();
}
/// Rotates a 3-vector assuming this is a unit quaternion (a spinor)
CUTLASS_HOST_DEVICE
Matrix3x1<Element> spinor(Matrix3x1<Element> const &rhs) const {
return (*this * Quaternion<Element>(rhs, 0) * conj(*this)).pure();
}
/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor)
CUTLASS_HOST_DEVICE
Matrix3x1<Element> spinor_inv(Matrix3x1<Element> const &rhs) const {
return (conj(*this) * Quaternion<Element>(rhs, 0) * *this).pure();
}
/// In-place addition
template <typename T>
CUTLASS_HOST_DEVICE
Quaternion<Element> &operator+=(Quaternion<Element> const &rhs) {
*this = (*this + rhs);
return *this;
}
/// In-place subtraction
template <typename T>
CUTLASS_HOST_DEVICE
Quaternion<Element> &operator-=(Quaternion<Element> const &rhs) {
*this = (*this - rhs);
return *this;
}
/// In-place multiplication
template <typename T>
CUTLASS_HOST_DEVICE
Quaternion<Element> &operator*=(Quaternion<Element> const &rhs) {
*this = (*this * rhs);
return *this;
}
/// Scalar multiplication
template <typename T>
CUTLASS_HOST_DEVICE
Quaternion<Element> &operator*=(Element s) {
*this = (*this * s);
return *this;
}
/// In-place Division
template <typename T>
CUTLASS_HOST_DEVICE
Quaternion<Element> &operator/=(Quaternion<Element> const &rhs) {
*this = (*this / rhs);
return *this;
}
/// In-place Division
template <typename T>
CUTLASS_HOST_DEVICE
Quaternion<Element> &operator/=(Element s) {
*this = (*this / s);
return *this;
}
/// Computes a 3x3 rotation matrix (row-major representation)
CUTLASS_HOST_DEVICE
Matrix3x3<Element> as_rotation_matrix_3x3() const {
Matrix3x3<Element> m(
w() * w() + x() * x() - y() * y() - z() * z(),
2 * x() * y() - 2 * w() * z(),
2 * x() * z() + 2 * w() * y(),
2 * x() * y() + 2 * w() * z(),
w() * w() - x() * x() + y() * y() - z() * z(),
2 * y() * z() - 2 * w() * x(),
2 * x() * z() - 2 * w() * y(),
2 * y() * z() + 2 * w() * x(),
w() * w() - x() * x() - y() * y() + z() * z()
);
return m;
}
/// Computes a 4x4 rotation matrix (row-major representation)
CUTLASS_HOST_DEVICE
Matrix4x4<Element> as_rotation_matrix_4x4() const {
Matrix4x4<Element> m = Matrix4x4<Element>::identity();
m.set_slice_3x3(as_rotation_matrix_3x3());
return m;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs a quaternion that is non-zero only in its real element.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_Quaternion(
Element w) { ///< real part
return Quaternion<Element>(w);
}
/// Constructs a quaternion from a vector and real
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_Quaternion(
Matrix3x1<Element> const &imag, ///< imaginary part as a vector
Element w) { ///< real part
return Quaternion<Element>(imag, w);
}
/// Constructs a quaternion from a unit-length rotation axis and a rotation
/// angle in radians
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_QuaternionRotation(
Matrix3x1<Element> const &axis_unit, ///< rotation axis (unit-length)
Element w) { ///< rotation angle in radians
return Quaternion<Element>::rotation(axis_unit, w);
}
/// Constructs a quaternion q = xi + yj + zk + w
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_Quaternion(Element x, Element y, Element z, Element w) {
return Quaternion<Element>(x, y, z, w);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns the magnitude of the complex number
template <typename Element>
CUTLASS_HOST_DEVICE
Element abs(Quaternion<Element> const &q) {
return fast_sqrt(norm(q));
}
/// Quaternion conjugate
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> conj(Quaternion<Element> const &q) {
return make_Quaternion(
-q.x(),
-q.y(),
-q.z(),
q.w()
);
}
/// Computes the squared magnitude of the quaternion
template <typename Element>
CUTLASS_HOST_DEVICE
Element norm(Quaternion<Element> const &q) {
return q.x() * q.x() + q.y() * q.y() + q.z() * q.z() + q.w() * q.w();
}
/// Quaternion reciprocal
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> reciprocal(Quaternion<Element> const &q) {
Element nsq = norm(q);
return make_Quaternion(
-q.x() / nsq,
-q.y() / nsq,
-q.z() / nsq,
q.w() / nsq
);
}
/// Returns a unit-length quaternion
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> unit(Quaternion<Element> const &q) {
Element rcp_mag = Element(1) / abs(q);
return make_Quaternion(
q.x() * rcp_mag,
q.y() * rcp_mag,
q.z() * rcp_mag,
q.w() * rcp_mag
);
}
/// Quaternion exponential
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> exp(Quaternion<Element> const &q) {
Element exp_ = fast_exp(q.w());
Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z());
Element sin_norm = fast_sin(imag_norm);
return make_Quaternion(
exp_ * q.x() * sin_norm / imag_norm,
exp_ * q.y() * sin_norm / imag_norm,
exp_ * q.z() * sin_norm / imag_norm,
exp_ * fast_cos(imag_norm)
);
}
/// Quaternion natural logarithm
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> log(Quaternion<Element> const &q) {
Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z());
Element s = fast_acos(q.w() / abs(q)) / v;
return make_Quaternion(
q.x() * s,
q.y() * s,
q.z() * s,
fast_log(abs(q))   // real part of log(q) is log(|q|)
);
}
/// Gets the rotation angle from a unit-length quaternion
template <typename Element>
CUTLASS_HOST_DEVICE
Element get_rotation_angle(Quaternion<Element> const &q_unit) {
return fast_acos(q_unit.w()) * Element(2);
}
/// Gets the rotation axis from a unit-length quaternion
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> get_rotation_axis(Quaternion<Element> const &q_unit) {
return q_unit.pure().unit();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Equality operator
template <typename Element>
CUTLASS_HOST_DEVICE
bool operator==(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
return lhs.x() == rhs.x() &&
lhs.y() == rhs.y() &&
lhs.z() == rhs.z() &&
lhs.w() == rhs.w();
}
/// Inequality operator
template <typename Element>
CUTLASS_HOST_DEVICE
bool operator!=(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
return !(lhs == rhs);
}
/// Quaternion scalar multiplication
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator*(Quaternion<Element> q, Element s) {
return make_Quaternion(
q.x() * s,
q.y() * s,
q.z() * s,
q.w() * s
);
}
/// Quaternion scalar multiplication
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator*(Element s, Quaternion<Element> const &q) {
return make_Quaternion(
s * q.x(),
s * q.y(),
s * q.z(),
s * q.w()
);
}
/// Quaternion scalar division
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator/(Quaternion<Element> const &q, Element s) {
return make_Quaternion(
q.x() / s,
q.y() / s,
q.z() / s,
q.w() / s
);
}
/// Quaternion unary negation
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator-(Quaternion<Element> const &q) {
return make_Quaternion(
-q.x(),
-q.y(),
-q.z(),
-q.w()
);
}
/// Quaternion addition
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator+(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
return make_Quaternion(
lhs.x() + rhs.x(),
lhs.y() + rhs.y(),
lhs.z() + rhs.z(),
lhs.w() + rhs.w()
);
}
/// Quaternion subtraction
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator-(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
return make_Quaternion(
lhs.x() - rhs.x(),
lhs.y() - rhs.y(),
lhs.z() - rhs.z(),
lhs.w() - rhs.w()
);
}
/// Quaternion product
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator*(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
return make_Quaternion(
lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y(),
lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z(),
lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x(),
lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z()
);
}
/// Quaternion division
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator/(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
return lhs * reciprocal(rhs);
}
/// Scalar divided by a quaternion
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator/(Element s, Quaternion<Element> const &q) {
return s * reciprocal(q);
}
/// Rotates a 3-vector assuming the given quaternion is unit-length (a spinor). This avoids
/// computing a reciprocal.
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> spinor_rotation(
Quaternion<Element> const &spinor, /// unit-length quaternion
Matrix3x1<Element> const &rhs) { /// arbitrary 3-vector
return (spinor * Quaternion<Element>(rhs, 0) * conj(spinor)).pure();
}
/// Inverse rotation of a 3-vector assuming the given quaternion is unit-length (a spinor). This
/// avoids computing a reciprocal.
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> spinor_rotation_inv(
Quaternion<Element> const &spinor, /// unit-length quaternion
Matrix3x1<Element> const &rhs) { /// arbitrary 3-vector
return (conj(spinor) * Quaternion<Element>(rhs, 0) * spinor).pure();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Output operators
//
template <typename Element>
std::ostream &operator<<(std::ostream &out, Quaternion<Element> const &q) {
return out << q.w() << "+i" << q.x() << "+j" << q.y() << "+k" << q.z();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
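To show how the pieces above fit together, a brief host-side sketch (hypothetical values; it assumes cutlass/quaternion.h as introduced in this file and the three-element Matrix3x1 constructor already used within it):

#include "cutlass/quaternion.h"

void quaternion_rotation_example() {
  // Unit quaternion encoding a 90-degree rotation about the +z axis.
  cutlass::Quaternion<float> q =
      cutlass::Quaternion<float>::rotation(0.0f, 0.0f, 1.0f, 3.14159265f / 2.0f);

  // Rotating the x unit vector yields approximately (0, 1, 0).
  cutlass::Matrix3x1<float> v = q.rotate(cutlass::Matrix3x1<float>(1.0f, 0.0f, 0.0f));
  (void)v;
}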

View File

@ -22,6 +22,11 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
\file
\brief This class provides helpers to support real<> and complex<> types in generic code.
*/
#pragma once
namespace cutlass {

View File

@ -1,179 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements a software-pipelined efficient batched reduction.
D = alpha * Reduction(A) + beta * C
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cuda.h>
#endif
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
#include "cutlass/fragment.h"
namespace cutlass {
namespace reduction {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename batched_reduction_>
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_::Params params) {
// Construct the batched_reduction object
batched_reduction_ batched_reduction(params);
batched_reduction.run();
}
template <typename BatchedReductionTraits_>
struct BatchedReduction {
/// This class
typedef BatchedReduction<BatchedReductionTraits_> This_;
/// The traits
typedef BatchedReductionTraits_ Traits;
/// Params
typedef typename Traits::Params Params;
/// functor
typedef typename Traits::Functor Functor;
/// ctor
CUTLASS_DEVICE BatchedReduction(Params const &params_)
: params(params_), functor(params_.functorParams) {}
/// main operation method
/// D = alpha * Reduction(A) + beta * C
CUTLASS_DEVICE void run() {
#if (__CUDA_ARCH__ >= 600)
// Swizzle the IDs of the block
typename Traits::BlockSwizzle block_swizzle;
Coord<3> threadblock_offset =
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());
int subTileSize = gridDim.x * Traits::SubTile::kW;
int tileSize = params.problem_size[1] * params.problem_size[2];
int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW;
int subTileBase = 0;
typename Traits::ScalarA inRegs[Traits::maxInReg];
typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
#pragma unroll
for (int subTile = 0; subTile < tileSize; subTile += subTileSize) {
int tileOffset = subTileBase + subTileOffset;
// Init AccumRegs
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW; i++)
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
// Fetch c0
typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
#pragma unroll
for (int i = 0; i< Traits::ThreadShape::kW; i++)
c0[i] = static_cast<typename Traits::ScalarAccum>(params.d_c[tileOffset + i]);
// Fetch partial sums from A
#pragma unroll
for (int s = 0; s < Traits::ReductionSize; s++) {
int inRegOffset = s * Traits::ThreadShape::kW;
int dOffset = (s * tileSize) + tileOffset;
#pragma unroll
for (int i = 0; i< Traits::ThreadShape::kW; i++) {
inRegs[inRegOffset + i] = params.d_a[dOffset + i];
}
}
// Accumulate
#pragma unroll
for (int s = 0; s < Traits::ReductionSize; s++) {
int inRegOffset = s * Traits::ThreadShape::kW;
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
//AccumRegs[i] = cuFma(params.alpha, inRegs[inRegOffset + i], AccumRegs[i]);
//AccumRegs[i] = params.alpha * inRegs[inRegOffset + i] + AccumRegs[i];
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(inRegs[inRegOffset + i]) + AccumRegs[i];
}
}
// calling functor
functor_caller<Traits::ThreadShapeMultiple2>(AccumRegs, c0, AccumRegs);
// Store AccumRegs to D
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
params.d_d[tileOffset + i] = static_cast<typename Traits::ScalarD>(AccumRegs[i]);
}
// Advance sub-tile pointer
subTileBase += subTileSize;
} // end for loop
#endif //#if (__CUDA_ARCH__ >= 600)
}
template<bool ThreadShapeMultiple2>
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output) {
if (ThreadShapeMultiple2 == true) {
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
}
}
else {
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
}
}
}
//
// Static function members
//
#if !defined(__CUDACC_RTC__)
/// Launch the kernel.
static __host__ cudaError_t launch(Params const& params,
cudaStream_t stream = cudaStreamDefault) {
// Setup the grid.
typename Traits::BlockSwizzle block_swizzle;
dim3 grid = block_swizzle.get_grid_layout(params.problem_size,
make_Coord_from_shape<typename Traits::OutputTile>());
dim3 block;
block.x = Traits::kThreads;
batched_reduction_kernel<This_><<<grid, block, 0, stream>>>(params);
return cudaGetLastError();
}
#endif
//
// Data members
//
/// The params.
Params const& params;
// The functor.
Functor functor;
};
} // namespace reduction
} // namespace cutlass

View File

@ -1,192 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties of complete batched reduction.
D = alpha * Reduction(A) + beta * C
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/shape.h"
#include "cutlass/reduction/threadblock_swizzle.h"
#include "cutlass/reduction/batched_reduction.h"
#include "cutlass/gemm/linear_scaling.h"
namespace cutlass {
namespace reduction {
/*
OutputTile defines the work load per thread block
Subtile defines the work load per thread block per iteration
OutputTile / Subtile = number of iterations within a kernel
ThreadShape defines the work load per thread
Subtile / ThreadShape = number of threads per thread block
*/
template <
/// The scalar type for A
typename ScalarA_,
/// The scalar type for C
typename ScalarC_,
/// The scalar type for D
typename ScalarD_,
/// the scalar type for alpha,
typename ScalarAlphaBeta_,
/// The scalar type for accumulator
typename ScalarAccum_,
/// Reduction work load per batch
int ReductionSize_ = 1,
/// The output tile, work load per thread block,
typename OutputTile_ = Shape<1, 1, 128>,
/// The subtile
typename SubTile_ = Shape<1, 1, 64>,
/// Work load per thread, per subtile
typename ThreadShape_ = Shape<1, 1, 2>,
/// The index
typename Index_ = int,
/// The block swizzle to reorganize the grid.
typename BlockSwizzle_ = DefaultBlockSwizzle,
/// The input register vector size in kernel
int maxInReg_ = 160,
/// The output register vector size in kernel
int maxOutReg_ = 64,
/// The functor that will be executed at the end
typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >
>
struct BatchedReductionTraits {
///
typedef BatchedReductionTraits<ScalarA_,
ScalarC_,
ScalarD_,
ScalarAlphaBeta_,
ScalarAccum_,
ReductionSize_,
OutputTile_,
SubTile_,
ThreadShape_,
Index_,
BlockSwizzle_,
maxInReg_,
maxOutReg_,
Functor_> This_;
/// The struct that consumes this Traits
typedef typename cutlass::reduction::BatchedReduction<This_> KernelClass;
///
typedef OutputTile_ OutputTile;
///
typedef SubTile_ SubTile;
///
typedef ThreadShape_ ThreadShape;
/// The input pointer type
typedef ScalarA_ ScalarA;
///
typedef ScalarC_ ScalarC;
/// The output pointer type
typedef ScalarD_ ScalarD;
/// The alpha beta type
typedef ScalarAlphaBeta_ ScalarAlphaBeta;
/// The type for accumulation
typedef ScalarAccum_ ScalarAccum;
/// The index
typedef Index_ Index;
/// The thread block swizzle
typedef BlockSwizzle_ BlockSwizzle;
///
static const int ReductionSize = ReductionSize_;
/// check if threadShape is multiple of 2.
static const bool ThreadShapeMultiple2 = (ThreadShape::kW % 2 == 0);
///
typedef Functor_ Functor;
/// Parameteres object constructable on the host
/// The number of threads per thread block. can be deduced
static int const kThreads = SubTile::kW / ThreadShape::kW;
//
static int const maxInReg = maxInReg_;
//
static int const maxOutReg = maxOutReg_;
//
static_assert(SubTile::kW % ThreadShape::kW == 0, "cannot evenly distribute work load among threads");
//
static_assert(kThreads % 32 == 0, "threads per threadblock is not multiple of 32");
//
static_assert(OutputTile::kW % SubTile::kW == 0, "cannot evenly distribute work load among iterations");
//
static_assert(ReductionSize * ThreadShape::kW <= maxInReg, "ReductionSize * ThreadShape::kW should not be bigger than maxInReg");
//
static_assert(ThreadShape::kW <= maxOutReg, "ThreadShape::kW should not be bigger than maxOutReg");
struct Params {
/// The dimension of output tensor
Coord<3> problem_size;
/// The alpha
ScalarAlphaBeta alpha;
/// The beta
ScalarAlphaBeta beta;
/// stride between two element that will be sumed
long long int reduction_stride;
//
ScalarA const *d_a;
//
Index lda;
//
ScalarC const *d_c;
//
Index ldc;
//
ScalarD *d_d;
//
Index ldd;
/// The functor params.
typename Functor::Params functorParams;
/// Initialize the parameters for 2D output tensor
CUTLASS_HOST_DEVICE int initialize(Index m_,
Index n_,
ScalarAlphaBeta alpha_,
ScalarAlphaBeta beta_,
long long int reduction_stride_,
ScalarA const *d_a_,
Index lda_,
ScalarC const *d_c_,
Index ldc_,
ScalarD *d_d_,
Index ldd_){
problem_size = make_Coord(1, n_, m_);
alpha = alpha_;
beta = beta_;
reduction_stride = reduction_stride_;
d_a = d_a_;
lda = lda_;
d_c = d_c_;
d_d = d_d_;
ldc = ldc_;
ldd = ldd_;
functorParams.initialize(alpha_, beta_);
return 0;
}
};
};
} // namespace reduction
} // namespace cutlass

View File

@ -77,6 +77,18 @@ bool relatively_equal<uint1b_t>(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) {
return (a == b);
}
template <>
CUTLASS_HOST_DEVICE
bool relatively_equal<int2b_t>(int2b_t a, int2b_t b, int2b_t, int2b_t) {
return (a == b);
}
template <>
CUTLASS_HOST_DEVICE
bool relatively_equal<uint2b_t>(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) {
return (a == b);
}
template <>
CUTLASS_HOST_DEVICE
bool relatively_equal<int4b_t>(int4b_t a, int4b_t b, int4b_t, int4b_t) {

Some files were not shown because too many files have changed in this diff