CUTLASS 2.3 initial commit (#134)
CUTLASS 2.3 adds GEMMs targeting Sparse Tensor Cores on the NVIDIA Ampere Architecture, fast SGEMM kernels, small matrix classes, bug fixes, and performance enhancements.
14 CHANGELOG.md
@@ -2,6 +2,20 @@

# CUTLASS 2.x

## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
  * [Sparse Tensor Core GEMM kernels](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
    * Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
  * Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
* Minor Features:
  * [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
  * Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code
  * [Floating-point constants](/include/cutlass/constants.h)
* NVIDIA Ampere GPU Architecture examples and documentation:
  * [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
  * [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
  * Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue)

## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
  * Fast Tensor Core operations:
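The activation-function epilogues listed above slot into the standard device-level GEMM template. Below is a minimal sketch of that wiring, assuming the `LinearCombinationSigmoid` class name from the header path and illustrative (not tuned) tile shapes; it mirrors the `cutlass::gemm::device::Gemm` instantiations added elsewhere in this commit and is not itself part of the change.

```cpp
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"

// Sketch only: an SM80 Tensor Core GEMM whose epilogue applies a sigmoid to
// alpha * (A @ B) + beta * C before writing D. Shapes are illustrative.
using GemmWithSigmoid = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::ColumnMajor,   // A
    cutlass::half_t, cutlass::layout::ColumnMajor,   // B
    cutlass::half_t, cutlass::layout::ColumnMajor,   // C / D
    float,                                           // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,          // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,            // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,             // Tensor Core instruction
    cutlass::epilogue::thread::LinearCombinationSigmoid<
        cutlass::half_t,
        128 / cutlass::sizeof_bits<cutlass::half_t>::value,  // elements per vectorized epilogue access
        float,                                               // accumulator element
        float                                                // compute element for alpha/beta
    >>;
```

The remaining template parameters (threadblock swizzle, stage count, alignments) are left at their defaults in this sketch.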
CMakeLists.txt

@@ -32,7 +32,7 @@ endif()

message(STATUS "CMake Version: ${CMAKE_VERSION}")

project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.3.0 LANGUAGES CXX)

include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

find_package(Doxygen QUIET)

@@ -69,6 +69,8 @@ endif()

set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Library")
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Profiler")

if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})

@@ -101,6 +103,9 @@ endif()

if (NOT CUDA_VERSION VERSION_LESS 11.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.1)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")

@@ -164,12 +169,14 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.

#
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")

# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})

#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.

@@ -181,6 +188,11 @@ else()

set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
endif()

# Trace levels for debugging
set(CUTLASS_DEBUG_TRACE_LEVEL "0" CACHE STRING "Level of debug tracing to perform.")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_TRACE_LEVEL})

set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.")

@@ -352,7 +364,7 @@ set_target_properties(CUTLASS PROPERTIES EXPORT_NAME cutlass)

set(CUTLASS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")

set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library/)
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library CACHE INTERNAL "Location of generator scripts")

# The following utility directory is needed even if the tools build is disabled, so it exists here.
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")
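The two `list(APPEND ...)` lines above turn the `CUTLASS_TEST_LEVEL` and `CUTLASS_DEBUG_TRACE_LEVEL` cache variables into preprocessor definitions on the nvcc/clang command line. A small, hypothetical host-side snippet showing how code can react to those definitions (the function name is illustrative, not part of CUTLASS):

```cpp
#include <cstdio>

// The build system passes -DCUTLASS_TEST_LEVEL=<n> and
// -DCUTLASS_DEBUG_TRACE_LEVEL=<n> on the compiler command line.
#if !defined(CUTLASS_TEST_LEVEL)
#define CUTLASS_TEST_LEVEL 0
#endif
#if !defined(CUTLASS_DEBUG_TRACE_LEVEL)
#define CUTLASS_DEBUG_TRACE_LEVEL 0
#endif

void report_build_configuration() {
  std::printf("Unit test level: %d (L0/L1/L2)\n", CUTLASS_TEST_LEVEL);
#if CUTLASS_DEBUG_TRACE_LEVEL > 0
  std::printf("Debug tracing enabled at level %d\n", CUTLASS_DEBUG_TRACE_LEVEL);
#endif
}
```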
@@ -16,6 +16,9 @@ Naila Farooqui

Piotr Majcher
Paul Springer
Jin Wang
Aniket Shivam
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
Markus Hohnerbach
Aditya Atluri

@@ -52,6 +55,8 @@ Olivier Giroux

Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Matthew Nicely
Joel McCormack
Kyrylo Perelygin
@@ -213,7 +213,14 @@ function(cutlass_correct_source_file_language_property)

endif()
endfunction()

set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
# If building with all kernels, set UNITY build on by default.
if (CUTLASS_LIBRARY_KERNELS MATCHES "all")
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
else()
set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF)
endif()

set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")

function(cutlass_unify_source_files TARGET_ARGS_VAR)
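For context on `CUTLASS_UNITY_BUILD_BATCH_SIZE`: a unity build concatenates several generated kernel sources into one translation unit so the compiler is invoked fewer times. The sketch below shows what one batched unit amounts to; the file names are hypothetical.

```cpp
// One "unity" translation unit, roughly: up to CUTLASS_UNITY_BUILD_BATCH_SIZE
// (16 by default) generated kernel sources compiled together as a single TU.
#include "generated/cutlass_gemm_kernel_000.cu"   // hypothetical generated source
#include "generated/cutlass_gemm_kernel_001.cu"   // hypothetical generated source
// ... additional generated sources, up to the configured batch size
```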
41 README.md

@@ -1,8 +1,8 @@



# CUTLASS 2.2
# CUTLASS 2.3

_CUTLASS 2.2 - June 2020_
_CUTLASS 2.3 - September 2020_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.

@@ -30,6 +30,14 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

# What's New in CUTLASS 2.3

CUTLASS 2.3 is a minor update to CUTLASS adding:
- GEMMs targeting structured [Sparse Tensor Cores](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) in NVIDIA Ampere Architecture GPUs
- Fast SGEMM kernels targeting GeForce RTX 30-series CUDA Cores
- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit)
- See the [CHANGELOG](CHANGELOG.md) for more details.

# What's New in CUTLASS 2.2

CUTLASS 2.2 is a significant update to CUTLASS adding:

@@ -42,7 +50,7 @@ CUTLASS 2.2 is a significant update to CUTLASS adding:

# What's New in CUTLASS 2.1

CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
CUTLASS 2.1 is a minor update to CUTLASS adding:

- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)

@@ -71,8 +79,8 @@ using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's

# Compatibility

CUTLASS requires a C++11 host compiler and
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
performs best when built with the [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, CUDA 10.2, and CUDA 11.0.

We have tested the following environments.

@@ -99,10 +107,11 @@ any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GP

|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA GeForce 3090|8.6|11.1|11.1|

# Documentation

CUTLASS 2.2 is described in the following documents and the accompanying
CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).

- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS

@@ -136,14 +145,14 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc

```

Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.

```
$ mkdir build && cd build

$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture
$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA's Ampere Architecture
```

From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make.

@@ -258,15 +267,25 @@ The `tools/profiler/` directory contains a command-line utility for launching ea

It can be built as follows:

```
$ make cutlass_profiler -j
$ make cutlass_profiler -j16
```

To limit compilation time, only one tile size is instantiated for each data type, math instruction, and layout.
By default, only one tile size is instantiated for each data type, math instruction, and layout.
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
Beware, this results in *thousands* of kernels and long build times.
```
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j
$ make cutlass_profiler -j16
```
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
wildcard characters may be used to reduce the set of kernels. The following builds exactly one kernel:

```
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1
...
$ make cutlass_profiler -j16
```

Example command line for profiling SGEMM kernels is as follows:
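The SGEMM kernels referenced above are instances of `cutlass::gemm::device::Gemm`. As a reminder of what they compute, here is a minimal host-side sketch in the style of the existing `00_basic_gemm` example (the `run_sgemm` helper is illustrative, not part of this commit); it performs D = alpha * A * B + beta * C on column-major single-precision operands:

```cpp
#include "cutlass/gemm/device/gemm.h"

using ColumnMajor = cutlass::layout::ColumnMajor;
using Sgemm = cutlass::gemm::device::Gemm<float, ColumnMajor,   // A
                                          float, ColumnMajor,   // B
                                          float, ColumnMajor>;  // C / D

// Launches a single-precision GEMM on the default stream.
cutlass::Status run_sgemm(int M, int N, int K,
                          float alpha, float const *A, int lda,
                          float const *B, int ldb,
                          float beta, float *C, int ldc) {
  Sgemm gemm_op;
  Sgemm::Arguments args({M, N, K},      // problem size
                        {A, lda},       // TensorRef to A
                        {B, ldb},       // TensorRef to B
                        {C, ldc},       // TensorRef to C (source)
                        {C, ldc},       // TensorRef to D (destination)
                        {alpha, beta}); // epilogue scalars
  return gemm_op(args);
}
```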
@@ -69,7 +69,7 @@

template <typename Element, typename GmemIterator, typename SmemIterator>
__global__ void kernel_dump(typename GmemIterator::Params params,
typename GmemIterator::TensorRef ref) {
__shared__ Element shared_storage[EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL];
extern __shared__ Element shared_storage[];

// Construct the global iterator and load the data to the fragments.
int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x;

@@ -164,8 +164,11 @@ int main() {

dim3 grid(1, 1);
dim3 block(32, 1, 1);

int smem_size =
int(sizeof(Element) * EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL);

kernel_dump<Element, GmemIterator, SmemIterator>
<<<grid, block>>>(params, matrix.device_ref());
<<<grid, block, smem_size, 0>>>(params, matrix.device_ref());

cudaError_t result = cudaDeviceSynchronize();
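For context on the change above: `extern __shared__` defers the shared-memory size to launch time, which is why the launch gains a third `<<<>>>` parameter. A standalone sketch of the same pattern (names are illustrative):

```cpp
__global__ void copy_through_smem(float const *in, float *out, int n) {
  extern __shared__ float smem[];                  // sized by the launch configuration
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  smem[threadIdx.x] = (idx < n) ? in[idx] : 0.0f;  // stage through shared memory
  __syncthreads();
  if (idx < n) {
    out[idx] = smem[threadIdx.x];
  }
}

// Launch: the third parameter is the dynamic shared-memory size in bytes.
//   int smem_bytes = int(sizeof(float)) * block.x;
//   copy_through_smem<<<grid, block, smem_bytes, 0>>>(in, out, n);
// If a kernel ever needs more than 48 KB per block (possible on SM80), opt in first with
//   cudaFuncSetAttribute(copy_through_smem, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
```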
@@ -50,7 +50,7 @@

To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.

$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
  -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex

This builds all planar complex GEMM variants for Volta and Turing architectures.

@@ -59,7 +59,7 @@

specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).

$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
  -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn

$ make 10_planar_complex

@@ -526,6 +526,11 @@ int main(int argc, char const **args) {

return 0;
}
}
else {
// NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond.
//
// fall through
}

//
// Parse options
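The added `else` branch above is part of a compute-capability dispatch in the example's `main()`. A simplified sketch of that kind of guard (error handling abbreviated; not the exact code from the example):

```cpp
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
  return -1;                 // could not query the device
}

if (props.major < 7) {
  // Volta Tensor Cores are required by this example; exit gracefully on older GPUs.
  return 0;
}
else if (props.major >= 8) {
  // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported
  // on CUDA 11 Toolkit and beyond -- fall through and run.
}
```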
@@ -48,7 +48,7 @@

To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.

$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
  -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex

This builds all planar complex GEMM variants for Volta and Turing architectures.

@@ -57,7 +57,7 @@

specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).

$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
  -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn

$ make 11_planar_complex_array

@@ -586,6 +586,11 @@ int main(int argc, char const **args) {

return 0;
}
}
else {
// NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond.
//
// fall through
}

//
// Parse options
205 examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h Normal file
@@ -0,0 +1,205 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_nonfused_gemm_s8_sm80() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
true
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
true
|
||||
>;
|
||||
|
||||
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
}
|
||||
|
||||
void run_fused_gemm_s8_sm80() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
true
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
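In both functions above the two problem sizes chain: the first GEMM's (M x 64) output becomes the second GEMM's A operand, so GEMM0's N must equal GEMM1's K and both share M. A small illustrative check (not part of the commit) makes that constraint explicit:

```cpp
#include <cassert>
#include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmCoord

// Sanity-check that two back-to-back GEMM problem sizes can be fused:
// GEMM0 is (M x N0 x K0), GEMM1 is (M x N1 x K1) with K1 == N0.
inline void check_b2b_problem_sizes(cutlass::gemm::GemmCoord const &problem_size_0,
                                    cutlass::gemm::GemmCoord const &problem_size_1) {
  assert(problem_size_0.m() == problem_size_1.m());  // here: 128*1600 == 128*1600
  assert(problem_size_0.n() == problem_size_1.k());  // here: 64 == 64
}
```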
|
||||
@@ -38,6 +38,8 @@
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
@@ -115,7 +117,9 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
@@ -232,6 +236,13 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
status = gemm_op_1.initialize(arguments_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = gemm_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
status = gemm_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
@@ -242,14 +253,14 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
@@ -261,9 +272,9 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
|
||||
std::cout << "total time " << totalTime / 100.0 << " ms\n";
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "total time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
@@ -302,7 +313,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
tensor_D0.device_ref(),
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
tensor_C1.device_ref(),
|
||||
@@ -420,7 +431,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
@@ -478,7 +491,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
|
||||
//Reorder B0
|
||||
cutlass::reorder_column<B2bGemm::InstructionShape::kK>(
|
||||
cutlass::reorder_column<16>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
@@ -526,6 +539,11 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
@@ -536,7 +554,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
@@ -546,7 +564,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / 100.0 << " ms\n";
|
||||
std::cout << "time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
//tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
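The changes above replace the hard-coded 100-iteration loops with `warm_ups` and `runs` parameters. The general warm-up-then-time pattern they implement looks like this (a generic sketch, not the exact code):

```cpp
#include <cuda_runtime.h>

// Returns the average time in milliseconds of `runs` invocations of `op`,
// after `warm_ups` untimed invocations to settle clocks and caches.
template <typename Op>
float average_time_ms(Op &&op, int warm_ups, int runs) {
  for (int i = 0; i < warm_ups; ++i) {
    op();                      // not timed
  }

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  for (int i = 0; i < runs; ++i) {
    op();
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float total_ms = 0.0f;
  cudaEventElapsedTime(&total_ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return total_ms / float(runs);
}
```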
|
||||
|
||||
@@ -30,7 +30,6 @@ two unfused GEMM operations, demonstrating a speedup of the fused kernel on the

NVIDIA Turing GPU architecture.

Problem size:

GEMM1 (M,N,K): 128*1600, 64, 576
GEMM2 (M,N,K): 128*1600, 128, 64

@@ -42,16 +41,17 @@ also requires warp_tile_N = thread_block_tile_N so the data required by each war

register-file-resident.

Performance:

- fp16 on Tesla T4 @ 1590MHz (non-fused vs. fused): 1.39011 ms vs. 1.26035 ms
- int8 on Tesla T4 @ 1590MHz (non-fused vs. fused): 0.751759 ms vs. 0.62971 ms
- fp16 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.721144 ms vs. 0.629864 ms
- int8 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.379049 ms vs. 0.324764 ms
- int8 on GA100 @ 1200MHz (non-fused vs. fused): 0.153795 ms vs. 0.129874 ms

*/

#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h"

int run() {

@@ -71,7 +71,10 @@ int run() {

return 0;
}

#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
run_nonfused_gemm_s8_sm80();
run_fused_gemm_s8_sm80();
#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
run_nonfused_gemm_f16();
run_fused_gemm_f16();
run_nonfused_gemm_s8();
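The fusion requirement stated in the hunk context above (thread_block_tile_N = problem_N and warp_tile_N = thread_block_tile_N, so each warp's intermediate tile stays register-file-resident) holds for the SM80 configuration added in this commit: ThreadblockShape1 is <64, 128, 64> against problem N = 128, and WarpShape1 is <32, 128, 64>. Illustrative compile-time checks (not part of the commit) that could sit next to those `using` declarations in `run_fused_gemm_s8_sm80()`:

```cpp
// Sketch of the B2B fusion constraints for the shapes defined in run_fused_gemm_s8_sm80():
static_assert(ThreadblockShape1::kN == 128,
              "2nd GEMM threadblock tile N must cover the entire problem N");
static_assert(WarpShape1::kN == ThreadblockShape1::kN,
              "2nd GEMM warp tile N must equal the threadblock tile N so each "
              "warp's A1 operand stays register-file-resident");
```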
@@ -210,7 +210,8 @@ struct B2bGemm {

// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;

cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);

// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||

@@ -313,7 +314,8 @@ struct B2bGemm {

// Masked tile iterators constructed from members
//

threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);

//assume identity swizzle
MatrixCoord threadblock_offset(
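The change above passes `params.grid_tiled_shape` into `get_tile_offset()`, letting the swizzle compute tile coordinates relative to the tiled grid. For the identity swizzle assumed by this kernel, the mapping reduces to the block indices; a simplified sketch (the real `GemmIdentityThreadblockSwizzle` may apply additional rasterization):

```cpp
#include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmCoord

// Simplified illustration of an identity threadblock swizzle: one threadblock
// per output tile, coordinates taken directly from the block indices.
__device__ cutlass::gemm::GemmCoord
identity_tile_offset(cutlass::gemm::GemmCoord grid_tiled_shape) {
  (void)grid_tiled_shape;  // used by rasterizing swizzles; unused in the identity case
  return cutlass::gemm::GemmCoord(int(blockIdx.x), int(blockIdx.y), int(blockIdx.z));
}
```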
@@ -217,7 +217,85 @@ struct DefaultB2bGemm<
|
||||
};
|
||||
|
||||
|
||||
/// Partial specialization for Turing IMMA Interleaved layout
|
||||
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Is Beta zero or not
|
||||
bool IsBetaZero>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, IsBetaZero> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0,
|
||||
true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK,
|
||||
IsBetaZero>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
|
||||
862 examples/13_fused_two_gemms/threadblock/b2b_mma_multistage.h Normal file
@@ -0,0 +1,862 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaMultistage :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over intermediate accumulator tile
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations0 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
static_assert(Base::kWarpGemmIterations1 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const TBLDGSTSIterationsA0 =
|
||||
IteratorA0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB0 =
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB1 =
|
||||
IteratorB1::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA0 =
|
||||
(TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB0 =
|
||||
(TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB1 =
|
||||
(TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
|
||||
};
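The three `kAccessesPerGroup*` constants above split the cp.async (LDGSTS) issues for the next pipeline stage evenly across the warp-level MMA iterations of the current stage, using a ceiling division. A small worked example with hypothetical counts:

```cpp
// (iterations + warp_gemm_iterations - 1) / warp_gemm_iterations is a ceiling division.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// e.g. 8 LDGSTS issues spread over 4 warp-level MMA iterations -> 2 issues per iteration
static_assert(ceil_div(8, 4) == 2, "evenly divisible case");
// e.g. 5 LDGSTS issues over 4 iterations -> still 2 per iteration, the last group is shorter
static_assert(ceil_div(5, 4) == 2, "remainder rounds up");
```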
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accumulator tile
|
||||
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
iterator_A0.set_iteration_index(group_start_A0 *
|
||||
IteratorA0::kAccessesPerVector);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0 *
|
||||
IteratorB0::kAccessesPerVector);
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
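`copy_tiles_and_advance_0()` above is built from the cp.async primitives pulled in via `cutlass/arch/memory.h`. Their typical ordering, condensed into one illustrative device function (size and cache policy are placeholders, not the values used by this class):

```cpp
// Sketch of the async-copy protocol used throughout this file.
__device__ void stage_copy_example(void *smem_dst, void const *gmem_src, bool valid) {
  // Enqueue one 16-byte global->shared copy; if `valid` is false the copy is skipped
  // (cp_async_zfill would zero-fill the destination instead).
  cutlass::arch::cp_async<16, cutlass::arch::CacheOperation::Global>(smem_dst, gmem_src, valid);

  // Commit all copies issued so far as one group ("stage").
  cutlass::arch::cp_async_fence();

  // Wait until at most N committed groups remain in flight (0 here: wait for all),
  // then synchronize so every warp sees the data in shared memory.
  cutlass::arch::cp_async_wait<0>();
  __syncthreads();
}
```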
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
iterator_B1.set_iteration_index(group_start_B1 *
|
||||
IteratorB1::kAccessesPerVector);
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B1.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0)
|
||||
{
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations0 - 1) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
if (gemm_k_iterations_1 == 0) {
|
||||
// iterator_A1.clear_mask();
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
|
||||
#if 0
|
||||
iterator_A1.set_iteration_index(0);
|
||||
this->smem_iterator_A1_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsA1; ++j) {
|
||||
typename IteratorA1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA1::AccessType *>(
|
||||
this->smem_iterator_A1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA1::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA1::Element>::value *
|
||||
IteratorA1::ThreadMap::kElementsPerAccess /
|
||||
IteratorA1::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
#endif
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
//iterator_A1.add_tile_offset({0, 1});
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
//this->smem_iterator_A1_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
// FragmentC0 accum0 = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
// this->warp_tile_iterator_A1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (gemm_k_iterations_1 == 0) {
|
||||
// iterator_A1.clear_mask();
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
// this->warp_tile_iterator_A1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
int group_start_iteration_B1;
|
||||
|
||||
group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
int group_start_iteration_B1;
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
// --gemm_k_iterations_1;
|
||||
if (gemm_k_iterations_1 == 1) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -48,10 +48,6 @@ namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<int a>
|
||||
struct chk_val {
|
||||
static_assert(a==0, "check value");
|
||||
};
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined.h"
|
||||
#include "threadblock/b2b_mma_multistage.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -200,8 +201,6 @@ template <
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
@ -220,7 +219,7 @@ template <
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, true> {
|
||||
// Define the MmaCore components
|
||||
@ -251,7 +250,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
|
||||
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
// Use fragment iterator for A1 operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
@ -282,6 +281,111 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>;
|
||||
|
||||
// Use fragment iterator for A1 operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout,
|
||||
InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>;
|
||||
|
||||
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
27
examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
14_ampere_tf32_tensorop_gemm
|
||||
ampere_tf32_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,278 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
Please check example 07 and 08 for the basics of tensor op gemm kernels. On NVIDIA Ampere
|
||||
architecture, most concepts still hold. The two main differences are:
|
||||
|
||||
1. NVIDIA Ampere architecture introduces a new series of tensor core instructions (see
|
||||
include/cutlass/arch/mma_sm80.h) which are more efficient on Ampere.
|
||||
|
||||
2. NVIDIA Ampere architecture uses cp_async() to build multistage software pipeline to better hide
|
||||
latency (see include/cutlass/gemm/threadblock/mma_multistage.h)
|
||||
|
||||
Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h)
|
||||
data types in tensor cores. One big advantage is that we can load fp32 data and convert it
|
||||
implicitly to tf32 inside the GEMM kernel, which means no change is needed to accelerate traditional
|
||||
fp32 data by using NVIDIA Ampere architecture.
|
||||
*/
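// Illustrative note, not part of the original file: a minimal host-side sketch of what
// the implicit fp32 -> tf32 rounding amounts to numerically. TF32 keeps fp32's 8-bit
// exponent but only 10 explicit mantissa bits, so the 13 low mantissa bits are dropped
// inside the kernel (hardware rounds; this sketch simply truncates for clarity).
// float_to_tf32_truncated() is a hypothetical helper name, not a CUTLASS API.
//
//   #include <cstdint>
//   #include <cstring>
//
//   inline float float_to_tf32_truncated(float x) {
//     uint32_t bits;
//     std::memcpy(&bits, &x, sizeof(bits));   // reinterpret the fp32 bit pattern
//     bits &= 0xFFFFE000u;                    // clear the 13 least-significant mantissa bits
//     float y;
//     std::memcpy(&y, &bits, sizeof(y));
//     return y;                               // value exactly representable in tf32
//   }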
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = float; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = float; // <- data type of elements in input matrix A
|
||||
using ElementInputB = float; // <- data type of elements in input matrix B
|
||||
using ElementOutput = float; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- identity threadblock swizzling function
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
|
||||
// memory access. For a byte, it's 16
|
||||
// elements. This becomes the vector width of
|
||||
// math instructions in the epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 4;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
int run() {
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
|
||||
// Return 0 so tests are considered passing if run on unsupported platforms.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 5120;
|
||||
const int length_n = 4096;
|
||||
const int length_k = 4096;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(4),
|
||||
ElementInputA(-4),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(4),
|
||||
ElementInputB(-4),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(4),
|
||||
ElementOutput(-4),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
cutlass::reference::device::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue>
|
||||
gemm_device;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_device(problem_size,
|
||||
alpha,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
beta,
|
||||
tensor_c.device_ref(),
|
||||
tensor_ref_d.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d.sync_host();
|
||||
tensor_ref_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes when built on older Toolkits.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
27
examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
15_ampere_sparse_tensorop_gemm
|
||||
ampere_sparse_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,311 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere
|
||||
architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4.
|
||||
|
||||
Sparse GEMM kernels need to take an additional E matrix which stores the meta data. The format of
|
||||
the meta data is different for every data type. CUTLASS templates can automatically infer it based on
|
||||
input A and B. Check code below.
|
||||
|
||||
Moreover, matrix E needs to be preprocessed so that ldmatrix can be used to load it into the registers
|
||||
efficiently.
|
||||
*/
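// Illustrative note, not part of the original file: size bookkeeping used below, stated
// once for clarity. With 50% (2:4) structured sparsity, kSparse == 2, so only half of A's
// K extent is stored, and each ElementE packs the metadata for kElementsPerElementE
// logical elements of A. The extents allocated later in this file are therefore
//
//   tensor_a (compressed A): M x (K / kSparse)
//   tensor_e (metadata)    : M x (K / kSparse / kElementsPerElementE)
//
// tensor_e is then rewritten into tensor_e_reordered by cutlass::reorder_meta() so the
// kernel can fetch the metadata with ldmatrix.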
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/host_uncompress.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = int32_t; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<256, 128, 256>; // <- threadblock tile M = 256, N = 128, K = 256
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- identity threadblock swizzling function
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
|
||||
// memory access. For a byte, it's 16
|
||||
// elements. This becomes the vector width of
|
||||
// math instructions in the epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||
using ElementInputE = typename Gemm::ElementE;
|
||||
using LayoutInputE = typename Gemm::LayoutE;
|
||||
|
||||
// The properties below are defined in include/cutlass/arch/sp_mma_sm80.h
|
||||
// 50% Sparsity on Ampere
|
||||
constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data, in bits
|
||||
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
|
||||
int run() {
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
|
||||
// Return 0 so tests are considered passing if run on unsupported platforms.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 512;
|
||||
const int length_n = 512;
|
||||
const int length_k = 1024;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
|
||||
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
|
||||
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(1),
|
||||
ElementInputA(-1),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(1),
|
||||
ElementInputB(-1),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(1),
|
||||
ElementOutput(-1),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||
tensor_e.host_view(),
|
||||
1,
|
||||
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
|
||||
// instructions.
|
||||
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
|
||||
{problem_size.m(), problem_size.n(),
|
||||
problem_size.k() / kSparse / kElementsPerElementE});
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_e_reordered.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
tensor_e.device_ref(), // <- reference to matrix E on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// uncompress tensor_a based on meta data tensor_e. We need it for reference computing.
|
||||
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
|
||||
tensor_e.host_ref(), problem_size.m(), problem_size.k());
|
||||
|
||||
// Create instantiation for host reference gemm kernel
|
||||
cutlass::reference::host::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
typename Gemm::Operator>
|
||||
gemm_host;
|
||||
|
||||
// Launch host reference gemm kernel
|
||||
gemm_host(problem_size,
|
||||
alpha,
|
||||
tensor_a_uncompressed.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
beta,
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref());
|
||||
|
||||
// Copy output data from CUTLASS kernel to host for comparison
|
||||
tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes when built on older Toolkits.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
@ -71,6 +71,8 @@ foreach(EXAMPLE
|
||||
11_planar_complex_array
|
||||
12_gemm_bias_relu
|
||||
13_fused_two_gemms
|
||||
14_ampere_tf32_tensorop_gemm
|
||||
15_ampere_sparse_tensorop_gemm
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
@ -55,6 +55,17 @@ struct Sm75 {
|
||||
struct Sm80 {
|
||||
static int const kMinComputeCapability = 80;
|
||||
};
|
||||
struct Sm86 {
|
||||
static int const kMinComputeCapability = 86;
|
||||
};
|
||||
|
||||
/// Triggers a breakpoint on the device
|
||||
CUTLASS_DEVICE
|
||||
void device_breakpoint() {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile (" brkpt;\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -51,6 +51,8 @@ struct global_load;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The redundant mov PTX instruction is used to force the compiler to
|
||||
// initialize data to zero before ld.global
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
@ -83,7 +85,6 @@ struct global_load<AccessType,
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
|
||||
@ -150,6 +150,42 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, El
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specifies internal data type for computation
|
||||
struct SPFormatType {
|
||||
enum Kind {
|
||||
Thread
|
||||
};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Matrix multiply-add operation
|
||||
template <
|
||||
/// Size of the matrix product (concept: GemmShape)
|
||||
typename Shape_,
|
||||
/// Number of threads participating
|
||||
int kThreads_,
|
||||
/// Data type of A elements
|
||||
typename ElementA,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Inner product operator
|
||||
typename Operator,
|
||||
/// Specifies meta data format
|
||||
SPFormatType::Kind SPFormat = SPFormatType::Thread
|
||||
>
|
||||
struct SparseMma;
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
||||
@ -165,4 +201,5 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, El
|
||||
#include "cutlass/arch/mma_sm70.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
#include "cutlass/arch/sp_mma_sm80.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -53,6 +53,7 @@ template <
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -79,6 +80,7 @@ template <
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -106,6 +108,7 @@ template <
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -142,6 +145,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -181,6 +185,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -218,6 +223,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -255,6 +261,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -292,6 +299,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -327,6 +335,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
using Operator = OpMultiplyAddComplex;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -355,7 +364,8 @@ template <
|
||||
struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 1>;
|
||||
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<float, 1> &d,
|
||||
|
||||
@ -55,6 +55,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<2, 1, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -99,6 +100,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -143,6 +145,7 @@ struct Mma <
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
@ -196,7 +199,8 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<half_t, 4> &d,
|
||||
|
||||
@ -51,7 +51,8 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 4>;
|
||||
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<int, 1> &d,
|
||||
@ -98,6 +99,7 @@ struct Mma<
|
||||
OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<1, 1, 2>;
|
||||
using Operator = OpMultiplyAdd;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
|
||||
@ -723,7 +723,6 @@ struct Mma<
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE
|
||||
|
||||
@ -85,7 +85,7 @@ Array<T, N> mac(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c
  Array<T, N> d;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    d[i] = a[i] * b[i] + c;
    d[i] = a[i] * b[i] + c[i];
  }
  return d;
}
1591
include/cutlass/arch/sp_mma_sm80.h
Normal file
File diff suppressed because it is too large
@ -487,6 +487,46 @@ public:

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 1> make_Array(Element x) {
  Array<Element, 1> m;
  m[0] = x;
  return m;
}

template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 2> make_Array(Element x, Element y) {
  Array<Element, 2> m;
  m[0] = x;
  m[1] = y;
  return m;
}

template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 3> make_Array(Element x, Element y, Element z) {
  Array<Element, 3> m;
  m[0] = x;
  m[1] = y;
  m[2] = z;
  return m;
}

template <typename Element>
CUTLASS_HOST_DEVICE
Array<Element, 4> make_Array(Element x, Element y, Element z, Element w) {
  Array<Element, 4> m;
  m[0] = x;
  m[1] = y;
  m[2] = z;
  m[3] = w;
  return m;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////////////////////////
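As a minimal host-side sketch of the new helpers (the program below is hypothetical and assumes only the CUTLASS include path, compiled as a CUDA/C++ translation unit):

  #include "cutlass/array.h"

  int main() {
    // make_Array deduces Array<float, 4> from the four arguments
    cutlass::Array<float, 4> v = cutlass::make_Array(1.0f, 2.0f, 3.0f, 4.0f);

    float sum = 0;
    for (int i = 0; i < 4; ++i) {
      sum += v[i];  // Array<T, N> indexes like a fixed-size container
    }
    return (sum == 10.0f) ? 0 : 1;
  }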
@ -65,7 +65,7 @@ struct alignas(2) bfloat16_t {

  /// Default constructor
  CUTLASS_HOST_DEVICE
  bfloat16_t() { }
  bfloat16_t() : storage(0) { }

  /// Floating-point conversion - round toward nearest
  CUTLASS_HOST_DEVICE
@ -187,10 +187,12 @@ class complex
  /// Division
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator/(complex<A> const &rhs) const {
    T d = (rhs.real() * (rhs) + rhs.imag() * rhs.imag());
    T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag());

    return complex<T>((this->real() * (rhs) + this->imag() * rhs.imag()) / d,
                      (this->imag() * (rhs)-this->real() * rhs.imag()) / d);
    return complex<T>(
      (real() * rhs.real() + imag() * rhs.imag()) / d,
      (imag() * rhs.real() - real() * rhs.imag()) / d
    );
  }

  /// Scalar Division
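For reference, the rewritten operator computes the standard complex quotient, with a = real(), b = imag() of the left operand and c, d those of rhs:

  \frac{a + b\,i}{c + d\,i} \;=\; \frac{(a c + b d) + (b c - a d)\,i}{c^2 + d^2}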
1233
include/cutlass/constants.h
Normal file
File diff suppressed because it is too large
@ -439,6 +439,12 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
  return Coord<4>(values);
}

/// Helper to make a 5-element coordinate
CUTLASS_HOST_DEVICE
Coord<5> make_Coord(int _0, int _1, int _2, int _3, int _4) {
  int values[5] = {_0, _1, _2, _3, _4};
  return Coord<5>(values);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass
@ -31,18 +31,43 @@
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Output operator for CUDA built-in dim3 type
|
||||
inline std::ostream &operator<<(std::ostream &out, dim3 d) {
|
||||
return out << d.x << ", " << d.y << ", " << d.z;
|
||||
}
|
||||
|
||||
/// Output operator for CUDA built-in error type
|
||||
inline std::ostream &operator<<(std::ostream &out, cudaError_t error) {
|
||||
return out << cudaGetErrorString(error);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// stream operators for cutlass namespace //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element, int Rank>
|
||||
inline
|
||||
std::ostream& operator<<(std::ostream& out, Array<Element, Rank> const& v) {
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
out << (i ? ", " : "") << v[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <int Rank>
|
||||
inline
|
||||
std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
|
||||
@ -115,7 +140,7 @@ inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scal
|
||||
/// Default printing to ostream for MatrixShape
|
||||
template <int Row, int Column>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::MatrixShape<Row, Column> const &matrix_shape) {
|
||||
std::ostream & operator<<(std::ostream &out, MatrixShape<Row, Column> const &matrix_shape) {
|
||||
out << "cutlass::MatrixShape::(kRow, kColumn) {"
|
||||
<< cutlass::MatrixShape<Row,Column>::kRow <<","
|
||||
<< cutlass::MatrixShape<Row,Column>::kColumn <<"}";
|
||||
@ -130,7 +155,7 @@ namespace gemm {
|
||||
/// Default printing to ostream for GemmShape
|
||||
template <int M, int N, int K>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::gemm::GemmShape<M,N,K> const &gemm_shape) {
|
||||
std::ostream & operator<<(std::ostream &out, GemmShape<M,N,K> const &gemm_shape) {
|
||||
out << "cutlass::GemmShape::(kM, kN, kK) {"
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kM <<","
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kN <<","
|
||||
@ -150,7 +175,7 @@ namespace layout {
|
||||
/// Default printing to ostream for PitchLinearShape
|
||||
template < int Contiguous, int Strided>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::layout::PitchLinearShape<Contiguous, Strided> const &pitch_linear_shape) {
|
||||
std::ostream & operator<<(std::ostream &out, PitchLinearShape<Contiguous, Strided> const &pitch_linear_shape) {
|
||||
out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {"
|
||||
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kContiguous <<","
|
||||
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kStrided <<"}";
|
||||
|
||||
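A small host-side sketch exercising the stream operators added or simplified in core_io.h above (the standalone program is hypothetical; types and operators are the ones shown in this diff):

  #include <iostream>
  #include "cutlass/core_io.h"

  int main() {
    cutlass::Array<int, 3> a = cutlass::make_Array(1, 2, 3);
    std::cout << a << std::endl;                                    // "1, 2, 3"
    std::cout << cutlass::gemm::GemmShape<128, 128, 32>() << std::endl;
    std::cout << cutlass::MatrixShape<64, 64>() << std::endl;
    return 0;
  }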
@ -125,13 +125,6 @@ static char const* cutlassGetStatusString(cutlass::Status status) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct Debug {
|
||||
typename T::X x;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static const int NUM_THREADS_PER_WARP = 32;
|
||||
static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2;
|
||||
static const int NUM_THREADS_PER_QUAD = 4;
|
||||
@ -143,7 +136,7 @@ static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2;
|
||||
CUTLASS_DEVICE
|
||||
int LaneId() {
|
||||
int ret;
|
||||
asm ("mov.u32 %0, %%laneid;" : "=r"(ret));
|
||||
asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : );
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -151,7 +144,7 @@ int LaneId() {
|
||||
CUTLASS_DEVICE
|
||||
int SmId() {
|
||||
int ret;
|
||||
asm ("mov.u32 %0, %%smid;" : "=r"(ret));
|
||||
asm ("mov.u32 %0, %%smid;" : "=r"(ret) : );
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@ -31,9 +31,8 @@

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cutlass/constants.h"
#include "cutlass/complex.h"

#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/functional.h"
@ -108,6 +107,40 @@ struct Sigmoid<Array<T, N> > {
  }
};

// GELU operator
template <typename T>
struct GELU {
  CUTLASS_HOST_DEVICE
  T operator()(T const &scalar) const {
    return T(cutlass::constants::half<T>() * scalar *
      (cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
  }
};

template <>
struct GELU<float> {
  CUTLASS_HOST_DEVICE
  float operator()(float const &scalar) const {
    return cutlass::constants::half<float>() * scalar *
      (cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
  }
};

template <typename T, int N>
struct GELU<Array<T, N> > {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
    Array<T, N> y;
    GELU<T> gelu_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < int(rhs.size()); ++i) {
      y[i] = gelu_op(rhs[i]);
    }

    return y;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
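The new functor evaluates the exact, erf-based GELU (not the tanh approximation):

  \mathrm{GELU}(x) \;=\; \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)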
@ -95,6 +95,13 @@ public:
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha
|
||||
): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
@ -102,6 +109,13 @@ public:
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
@ -236,6 +236,18 @@ public:
|
||||
using ElementAccumulator = int;
|
||||
using ElementCompute = float;
|
||||
|
||||
static_assert(
|
||||
platform::is_same<ElementOutput, int32_t>::value ||
|
||||
platform::is_same<ElementOutput, uint32_t>::value ||
|
||||
platform::is_same<ElementOutput, int16_t>::value ||
|
||||
platform::is_same<ElementOutput, uint16_t>::value ||
|
||||
platform::is_same<ElementOutput, int8_t>::value ||
|
||||
platform::is_same<ElementOutput, uint8_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value,
|
||||
"This elementwise op expects the output to be int.");
|
||||
|
||||
static int const kCount = Count;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
@ -392,8 +404,9 @@ public:
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
/// Note: The below method only works for small k dimensions. The default
|
||||
/// approach is above
|
||||
/// Note: The below method works only when problem_size_K <= 256 for signed int8 gemm
|
||||
/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is
|
||||
/// above.
|
||||
/// TODO: Add logic to fallback to the default approach
|
||||
template <
|
||||
/// Data type used to load and store tensors
|
||||
@ -408,6 +421,18 @@ class FastLinearCombinationClamp {
|
||||
using ElementAccumulator = int;
|
||||
using ElementCompute = float;
|
||||
|
||||
static_assert(
|
||||
platform::is_same<ElementOutput, int32_t>::value ||
|
||||
platform::is_same<ElementOutput, uint32_t>::value ||
|
||||
platform::is_same<ElementOutput, int16_t>::value ||
|
||||
platform::is_same<ElementOutput, uint16_t>::value ||
|
||||
platform::is_same<ElementOutput, int8_t>::value ||
|
||||
platform::is_same<ElementOutput, uint8_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value,
|
||||
"This elementwise op expects the output to be int.");
|
||||
|
||||
static int const kCount = Count;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
|
||||
206
include/cutlass/epilogue/thread/linear_combination_gelu.h
Normal file
@ -0,0 +1,206 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination with GELU operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Applies a linear combination operator to an array of elements.
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
template <
|
||||
typename ElementOutput_, ///< Data type used to load and store tensors
|
||||
int Count, ///< Number of elements computed per operation
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
class LinearCombinationGELU {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementOutput_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
|
||||
static int const kCount = Count;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||
using ComputeFragment = Array<ElementCompute, kCount>;
|
||||
|
||||
static FloatRoundStyle const kRound = Round;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta
|
||||
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object, possibly loading from pointers in host memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
LinearCombinationGELU(Params const ¶ms) {
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const {
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes: D = gelu( alpha * accumulator + beta * source )
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source) const {
|
||||
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_source;
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
GELU<ComputeFragment> gelu;
|
||||
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
|
||||
intermediate = gelu(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes: D = gelu( alpha * accumulator )
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert accumulator to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_accumulator;
|
||||
GELU<ComputeFragment> gelu;
|
||||
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
intermediate = gelu(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
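A hedged sketch of how the functor above is invoked directly on register fragments; in a full GEMM it is normally supplied as the epilogue output operator, but this snippet only exercises the interface defined in the new file (the function name is hypothetical):

  #include "cutlass/epilogue/thread/linear_combination_gelu.h"

  using Gelu = cutlass::epilogue::thread::LinearCombinationGELU<
      float,   // ElementOutput
      4>;      // elements computed per operation

  CUTLASS_HOST_DEVICE
  void apply_gelu_epilogue(
      cutlass::Array<float, 4> const &accum,
      cutlass::Array<float, 4> const &source,
      cutlass::Array<float, 4> &out) {

    Gelu::Params params(/* alpha = */ 1.0f, /* beta = */ 1.0f);
    Gelu op(params);

    // D = gelu(alpha * accumulator + beta * source)
    out = op(accum, source);
  }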
@ -78,7 +78,8 @@ template <
|
||||
int ElementsPerAccess,
|
||||
/// Multiply-add operator
|
||||
/// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
||||
typename Operator_ = arch::OpMultiplyAddComplex>
|
||||
typename Operator_ = arch::OpMultiplyAddComplex
|
||||
>
|
||||
struct DefaultEpilogueComplexTensorOp {
|
||||
|
||||
using Shape = Shape_;
|
||||
@ -87,7 +88,6 @@ struct DefaultEpilogueComplexTensorOp {
|
||||
using OutputOp = OutputOp_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
using Operator = Operator_;
|
||||
|
||||
using ElementOutput = typename OutputOp::ElementOutput;
|
||||
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
||||
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
||||
@ -164,7 +164,8 @@ template <
|
||||
>
|
||||
struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
|
||||
OutputOp_, ElementsPerAccess,
|
||||
arch::OpMultiplyAddGaussianComplex> {
|
||||
arch::OpMultiplyAddGaussianComplex
|
||||
> {
|
||||
|
||||
using Shape = Shape_;
|
||||
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
||||
@ -172,7 +173,6 @@ struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
|
||||
using OutputOp = OutputOp_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
using Operator = arch::OpMultiplyAddGaussianComplex;
|
||||
|
||||
using ElementOutput = typename OutputOp::ElementOutput;
|
||||
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
||||
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
||||
|
||||
@ -251,7 +251,6 @@ struct DefaultEpilogueTensorOp {
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
using OutputOp = OutputOp_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputOp::ElementOutput;
|
||||
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
||||
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
||||
|
||||
@ -68,7 +68,7 @@ struct DefaultThreadMapSimt {
|
||||
|
||||
static_assert(
|
||||
!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
|
||||
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = gemm::GemmShape<
|
||||
|
||||
@ -69,7 +69,7 @@ struct DefaultThreadMapTensorOp {
|
||||
|
||||
static_assert(
|
||||
!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
|
||||
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = gemm::GemmShape<
|
||||
@ -119,7 +119,7 @@ struct DefaultInterleavedThreadMapTensorOp {
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM),
|
||||
!(ThreadblockShape::kN % WarpShape::kN),
|
||||
"Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
|
||||
@ -88,7 +88,7 @@ struct DefaultThreadMapVoltaTensorOp<
|
||||
|
||||
static_assert(
|
||||
!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
|
||||
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = gemm::GemmShape<
|
||||
@ -169,7 +169,7 @@ struct DefaultThreadMapVoltaTensorOp<
|
||||
|
||||
static_assert(
|
||||
!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
|
||||
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = gemm::GemmShape<
|
||||
|
||||
@ -71,7 +71,7 @@ struct DefaultThreadMapWmmaTensorOp {
|
||||
|
||||
static_assert(
|
||||
!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
|
||||
!(ThreadblockShape::kN % WarpShape::kN), "Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = gemm::GemmShape<
|
||||
|
||||
@ -104,7 +104,6 @@ public:
|
||||
using OutputOp = OutputOp_;
|
||||
using Padding = Padding_;
|
||||
|
||||
/// Output layout is always row-major
|
||||
using Layout = layout::RowMajor;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
@ -164,7 +163,7 @@ public:
|
||||
typename Base::SharedStorage &shared_storage, ///< Shared storage object
|
||||
int thread_idx, ///< ID of a thread within the threadblock
|
||||
int warp_idx, ///< ID of warp within threadblock
|
||||
int lane_idx ///< Id of thread within warp
|
||||
int lane_idx ///< Id of thread within warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
shared_load_iterator_(shared_storage.reference(), thread_idx) { }
|
||||
@ -192,7 +191,8 @@ private:
|
||||
void compute_source_not_needed_(
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators) { ///< Complete warp-level accumulator tile
|
||||
AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile
|
||||
) {
|
||||
|
||||
//
|
||||
// Iterator over warp-level accumulator fragment
|
||||
@ -259,9 +259,9 @@ private:
|
||||
// Store the final result
|
||||
//
|
||||
|
||||
destination_iterator.store(output_fragment);
|
||||
destination_iterator.store(output_fragment);
|
||||
++destination_iterator;
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -272,7 +272,8 @@ private:
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
) {
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment;
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -54,7 +55,7 @@ namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator used to load output tile from shared memory in epilogue.
|
||||
/// Tile iterator used to load and store output tile from shared memory in epilogue.
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
|
||||
///
|
||||
@ -243,8 +244,8 @@ public:
|
||||
TensorCoord extent,
|
||||
int thread_idx,
|
||||
TensorCoord threadblock_offset = TensorCoord()
|
||||
):
|
||||
params_(params) {
|
||||
): params_(params)
|
||||
{
|
||||
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
||||
|
||||
@ -336,7 +337,8 @@ public:
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
load_with_byte_offset(frag, 0);
|
||||
load_with_byte_offset(frag, 0);
|
||||
|
||||
}
|
||||
|
||||
/// Stores a fragment to memory
|
||||
@ -397,7 +399,8 @@ public:
|
||||
/// Stores a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &frag) {
|
||||
store_with_byte_offset(frag, 0);
|
||||
store_with_byte_offset(frag, 0);
|
||||
|
||||
}
|
||||
|
||||
/// Advances to the next position to load or store
|
||||
|
||||
@ -60,8 +60,8 @@ struct TensorOpPolicy<WarpShape, OperatorShape, layout::RowMajor> {
|
||||
|
||||
/// Number of operations
|
||||
using OperatorCount = MatrixShape<
|
||||
WarpShape::kM / OperatorShape::kM,
|
||||
WarpShape::kN / OperatorShape::kN
|
||||
(WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM,
|
||||
(WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN
|
||||
>;
|
||||
|
||||
//
|
||||
@ -70,6 +70,8 @@ struct TensorOpPolicy<WarpShape, OperatorShape, layout::RowMajor> {
|
||||
|
||||
static int const kElementsPerAccess = 2;
|
||||
static int const kRowsPerIteration = 8;
|
||||
static bool const kDivisible =
|
||||
!(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN);
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
|
||||
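The OperatorCount change above switches from exact division to a rounded-up count, with kDivisible recording whether the warp tile is an exact multiple of the instruction shape. An illustrative (not representative of any shipped configuration) check of the arithmetic:

  #include "cutlass/gemm/gemm.h"

  // Shapes chosen purely for illustration; kN = 40 is not a multiple of 8.
  using WarpShape     = cutlass::gemm::GemmShape<64, 40, 16>;
  using OperatorShape = cutlass::gemm::GemmShape<16, 8, 8>;

  static_assert((WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM == 4, "rows of operators");
  static_assert((WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN == 5, "columns round up");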
@ -29,6 +29,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
|
||||
@ -116,6 +117,9 @@ private:
|
||||
/// Internal layout object
|
||||
Layout layout_;
|
||||
|
||||
/// Thread offset
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default constructor
|
||||
@ -129,12 +133,16 @@ public:
|
||||
unsigned lane_id
|
||||
):
|
||||
pointer_(reinterpret_cast<AccessType *>(ref.data())),
|
||||
layout_(ref.stride()[0] / Policy::kElementsPerAccess) {
|
||||
layout_(ref.stride()[0] / Policy::kElementsPerAccess) {
|
||||
|
||||
int quad_id = (lane_id / Detail::kLanesInQuad);
|
||||
int lane_in_quad = (lane_id % Detail::kLanesInQuad);
|
||||
|
||||
pointer_ += layout_({quad_id, lane_in_quad});
|
||||
thread_offset_ = {
|
||||
quad_id, lane_in_quad * Policy::kElementsPerAccess
|
||||
};
|
||||
|
||||
pointer_ += layout_({thread_offset_.row(), thread_offset_.column() / Policy::kElementsPerAccess});
|
||||
}
|
||||
|
||||
/// Adds a pointer offset
|
||||
@ -148,9 +156,16 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) {
|
||||
|
||||
pointer_ += layout_({
|
||||
MatrixCoord coord_offset(
|
||||
tile_offset.row() * Shape::kRow,
|
||||
(tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess)
|
||||
tile_offset.column() * Shape::kColumn
|
||||
);
|
||||
|
||||
thread_offset_ += coord_offset;
|
||||
|
||||
pointer_ += layout_({
|
||||
coord_offset.row(),
|
||||
coord_offset.column() / Policy::kElementsPerAccess
|
||||
});
|
||||
|
||||
return *this;
|
||||
@ -198,6 +213,235 @@ public:
|
||||
void load(Fragment &frag) const {
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOp & operator++() {
|
||||
return add_tile_offset({1, 0});
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template for reading and writing tiles of accumulators to shared memory
|
||||
template <
|
||||
typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
|
||||
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
|
||||
typename Element_, ///< data type of element to be written
|
||||
typename Layout_
|
||||
>
|
||||
class TileIteratorTensorOpCanonical {
|
||||
public:
|
||||
|
||||
using WarpShape = WarpShape_;
|
||||
using OperatorShape = OperatorShape_;
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>; ///< Tensor Reference object
|
||||
using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor
|
||||
using Index = typename TensorRef::Index;
|
||||
using LongIndex = typename TensorRef::LongIndex;
|
||||
|
||||
using Policy = TensorOpPolicy<WarpShape, OperatorShape, Layout>;
|
||||
|
||||
static int const kAccessSize = 1;
|
||||
static int const kAccessCount = Policy::kElementsPerAccess / kAccessSize;
|
||||
|
||||
/// Shape of the tile in memory
|
||||
using Shape = MatrixShape<
|
||||
Policy::kRowsPerIteration,
|
||||
WarpShape::kN
|
||||
>;
|
||||
|
||||
/// This is the fragment size produced by one access of the iterator.
|
||||
using Fragment = Array<
|
||||
Element,
|
||||
Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
|
||||
|
||||
/// This is the complete warp-level accumulator tile.
|
||||
//using AccumulatorTile = typename Operator::FragmentC;
|
||||
|
||||
/// Number of times this iterator can be incremented
|
||||
static int const kIterations = Policy::kIterations;
|
||||
|
||||
// Internal constants
|
||||
struct Detail {
|
||||
static int const kLanesInQuad = 4;
|
||||
};
|
||||
|
||||
/// Padding quantity
|
||||
using Padding = MatrixShape<
|
||||
0,
|
||||
Detail::kLanesInQuad * Policy::kElementsPerAccess>;
|
||||
|
||||
private:
|
||||
|
||||
/// Storage type for accessing memory
|
||||
using AccessType = AlignedArray<Element, kAccessSize>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Internal pointer to memory
|
||||
AccessType *pointer_;
|
||||
|
||||
/// Internal layout object
|
||||
Layout layout_;
|
||||
|
||||
/// Guard to indicate whether the shape is divisible
|
||||
bool divisible_;
|
||||
|
||||
/// Extent of the output tensor
|
||||
MatrixCoord extent_;
|
||||
|
||||
/// Thread offset
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical(): pointer_(nullptr) { }
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical(
|
||||
TensorRef const &ref,
|
||||
unsigned lane_id
|
||||
):
|
||||
pointer_(reinterpret_cast<AccessType *>(ref.data())),
|
||||
layout_(ref.stride()[0]),
|
||||
divisible_(true),
|
||||
extent_(WarpShape::kM, WarpShape::kN) {
|
||||
|
||||
int quad_id = (lane_id / Detail::kLanesInQuad);
|
||||
int lane_in_quad = (lane_id % Detail::kLanesInQuad);
|
||||
|
||||
thread_offset_ = {
|
||||
quad_id, lane_in_quad * Policy::kElementsPerAccess
|
||||
};
|
||||
|
||||
pointer_ += layout_({thread_offset_.row(), thread_offset_.column()});
|
||||
}
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical(
|
||||
TensorRef const &ref,
|
||||
TensorCoord const &extent,
|
||||
unsigned lane_id
|
||||
):
|
||||
pointer_(reinterpret_cast<AccessType *>(ref.data())),
|
||||
layout_(ref.stride()[0]),
|
||||
divisible_(false),
|
||||
extent_(extent) {
|
||||
|
||||
int quad_id = (lane_id / Detail::kLanesInQuad);
|
||||
int lane_in_quad = (lane_id % Detail::kLanesInQuad);
|
||||
|
||||
thread_offset_ = {
|
||||
quad_id, lane_in_quad * Policy::kElementsPerAccess
|
||||
};
|
||||
|
||||
pointer_ += layout_({thread_offset_.row(), thread_offset_.column()});
|
||||
}
|
||||
|
||||
/// Adds a pointer offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical & add_pointer_offset(Index pointer_offset) {
|
||||
pointer_ += pointer_offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical & add_tile_offset(TensorCoord const &tile_offset) {
|
||||
|
||||
MatrixCoord coord_offset(
|
||||
tile_offset.row() * Shape::kRow,
|
||||
tile_offset.column() * Shape::kColumn
|
||||
);
|
||||
|
||||
thread_offset_ += coord_offset;
|
||||
|
||||
pointer_ += layout_({
|
||||
coord_offset.row(),
|
||||
coord_offset.column()
|
||||
});
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical & operator+=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(tile_offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Store
|
||||
CUTLASS_HOST_DEVICE
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
||||
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int a = 0; a < kAccessCount; ++a) {
|
||||
|
||||
int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a;
|
||||
int frag_idx = n * kAccessCount + a;
|
||||
|
||||
int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a;
|
||||
|
||||
if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) {
|
||||
pointer_[ptr_idx] = frag_ptr[frag_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Store
|
||||
CUTLASS_HOST_DEVICE
|
||||
void store(Fragment const &frag) {
|
||||
store_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
/// Load
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
|
||||
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int a = 0; a < kAccessCount; ++a) {
|
||||
|
||||
int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a;
|
||||
int frag_idx = n * kAccessCount + a;
|
||||
|
||||
int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a;
|
||||
|
||||
if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) {
|
||||
frag_ptr[frag_idx] = pointer_[ptr_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag) const {
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorTensorOpCanonical & operator++() {
|
||||
return add_tile_offset({1, 0});
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -33,6 +33,7 @@
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
|
||||
#include "cutlass/epilogue/warp/tensor_op_policy.h"
|
||||
#include "cutlass/epilogue/warp/volta_tensor_op_policy.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -40,6 +41,8 @@
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/******************************************************************************
|
||||
* Static math utilities
|
||||
******************************************************************************/
|
||||
@ -136,6 +139,19 @@ CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
  return temp ? (a / temp * b) : 0;
}

/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
constexpr int round_up(int a, int b) {
  return ((a + b - 1) / b) * b;
}

/// Returns the ceiling of (a / b)
CUTLASS_HOST_DEVICE
constexpr int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}

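Because both helpers are constexpr they can be checked at compile time; a quick sanity check (values chosen for illustration, assuming the usual cutlass/fast_math.h include):

  #include "cutlass/fast_math.h"

  static_assert(cutlass::ceil_div(130, 32) == 5,   "130 elements need 5 tiles of 32");
  static_assert(cutlass::round_up(130, 32) == 160, "rounded up to the next multiple of 32");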
/**
|
||||
* log2 computation, what's the
|
||||
* difference between the below codes and
|
||||
@ -189,7 +205,6 @@ void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigne
|
||||
|
||||
// The remainder.
|
||||
rem = src - (quo * div);
|
||||
|
||||
}
|
||||
|
||||
// For long int input
|
||||
@ -206,17 +221,56 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul,
|
||||
rem = src - (quo * div);
|
||||
}
|
||||
|
||||
/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
|
||||
CUTLASS_HOST_DEVICE
|
||||
int round_up(int a, int b) {
|
||||
return ((a + b - 1) / b) * b;
|
||||
}
|
||||
/// Object to encapsulate the fast division+modulus operation.
|
||||
///
|
||||
/// This object precomputes two values used to accelerate the computation and is best used
|
||||
/// when the divisor is a grid-invariant. In this case, it may be computed in host code and
|
||||
/// marshalled along other kernel arguments using the 'Params' pattern.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
///
|
||||
/// int quotient, remainder, dividend, divisor;
|
||||
///
|
||||
/// FastDivmod divmod(divisor);
|
||||
///
|
||||
/// divmod(quotient, remainder, dividend);
|
||||
///
|
||||
/// // quotient = (dividend / divisor)
|
||||
/// // remainder = (dividend % divisor)
|
||||
///
|
||||
struct FastDivmod {
|
||||
|
||||
/// Returns the ceiling of (a / b)
|
||||
CUTLASS_HOST_DEVICE
|
||||
int ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
int divisor;
|
||||
unsigned int multiplier;
|
||||
unsigned int shift_right;
|
||||
|
||||
/// Construct the FastDivmod object, in host code ideally.
|
||||
///
|
||||
/// This precomputes some values based on the divisor and is computationally expensive.
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
FastDivmod(): divisor(0), multiplier(0), shift_right(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
FastDivmod(int divisor_): divisor(divisor_) {
|
||||
find_divisor(multiplier, shift_right, divisor);
|
||||
}
|
||||
|
||||
/// Computes integer division and modulus using precomputed values. This is computationally
|
||||
/// inexpensive.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(int &quotient, int &remainder, int dividend) const {
|
||||
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
|
||||
}
|
||||
|
||||
/// Computes integer division and modulus using precomputed values. This is computationally
|
||||
/// inexpensive.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(int &quotient, int64_t &remainder, int64_t dividend) const {
|
||||
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
|
||||
}
|
||||
};
|
||||
|
||||
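A host-side sanity check of FastDivmod against ordinary integer division (illustrative only; in real use the object is constructed once on the host and marshalled to the kernel with the other parameters):

  #include <cassert>
  #include "cutlass/fast_math.h"

  int main() {
    cutlass::FastDivmod divmod(3);    // divisor known up front; precomputes multiplier/shift

    int quotient, remainder;
    divmod(quotient, remainder, 10);  // uses the precomputed values, no runtime division

    assert(quotient == 10 / 3 && remainder == 10 % 3);
    return 0;
  }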
/******************************************************************************
|
||||
* Min/Max
|
||||
@ -242,4 +296,117 @@ constexpr int const_max(int a, int b) {
|
||||
return (b > a ? b : a);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_cos(float theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::cosf(theta);
|
||||
#else
|
||||
return std::cos(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
double fast_cos(double theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::cos(theta);
|
||||
#else
|
||||
return std::cos(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_sin(float theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::sinf(theta);
|
||||
#else
|
||||
return std::sin(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
double fast_sin(double theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::sin(theta);
|
||||
#else
|
||||
return std::sin(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_acos(float theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::acosf(theta);
|
||||
#else
|
||||
return std::acos(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
double fast_acos(double theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::acos(theta);
|
||||
#else
|
||||
return std::acos(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_asin(float theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::asinf(theta);
|
||||
#else
|
||||
return std::asin(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
double fast_asin(double theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::asin(theta);
|
||||
#else
|
||||
return std::asin(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_sqrt(float theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::sqrtf(theta);
|
||||
#else
|
||||
return std::sqrt(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
double fast_sqrt(double theta) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::sqrt(theta);
|
||||
#else
|
||||
return std::sqrt(theta);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_log(float x) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::logf(x);
|
||||
#else
|
||||
return std::log(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
double fast_log(double x) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return ::log(x);
|
||||
#else
|
||||
return std::log(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
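The fast_* wrappers above pick the CUDA math intrinsic under __CUDA_ARCH__ and fall back to std:: on the host, so the same expression compiles for both targets. A minimal sketch (the function name is hypothetical):

  #include "cutlass/fast_math.h"

  CUTLASS_HOST_DEVICE
  void polar_to_cartesian(float r, float theta, float &x, float &y) {
    x = r * cutlass::fast_cos(theta);
    y = r * cutlass::fast_sin(theta);
  }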
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -32,9 +32,7 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
@ -69,6 +67,82 @@ struct multiplies {
|
||||
}
|
||||
};
|
||||
|
||||
/// Squares with optional conversion
|
||||
template <typename T, typename Output = T>
|
||||
struct square {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Output operator()(T lhs) const {
|
||||
multiplies<Output> mul_op;
|
||||
|
||||
Output y = Output(lhs);
|
||||
return mul_op(y, y);
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns the magnitude squared of an element.
|
||||
template <typename T, typename Output = T>
|
||||
struct magnitude_squared {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Output operator()(T lhs) const {
|
||||
multiplies<Output> mul_op;
|
||||
|
||||
Output y = Output(lhs);
|
||||
return mul_op(y, y);
|
||||
}
|
||||
};
|
||||
|
||||
/// Squares with optional conversion
|
||||
template <typename T, typename Output>
|
||||
struct magnitude_squared<complex<T>, Output> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Output operator()(complex<T> lhs) const {
|
||||
multiplies<Output> mul_op;
|
||||
|
||||
Output y_r = Output(lhs.real());
|
||||
Output y_i = Output(lhs.imag());
|
||||
|
||||
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes the square of a difference with optional conversion
|
||||
template <typename T, typename Output = T>
|
||||
struct square_difference {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Output operator()(T lhs, T rhs) const {
|
||||
multiplies<Output> mul_op;
|
||||
|
||||
Output y = Output(lhs) - Output(rhs);
|
||||
return mul_op(y, y);
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes the square of a difference with optional conversion
|
||||
template <typename T, typename Output = T>
|
||||
struct magnitude_squared_difference {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Output operator()(T lhs, T rhs) const {
|
||||
multiplies<Output> mul_op;
|
||||
|
||||
Output y = Output(lhs) - Output(rhs);
|
||||
return mul_op(y, y);
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes the square of a difference with optional conversion
|
||||
template <typename T, typename Output>
|
||||
struct magnitude_squared_difference<complex<T>, Output> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Output operator()(complex<T> lhs, complex<T> rhs) const {
|
||||
multiplies<Output> mul_op;
|
||||
|
||||
Output y_r = Output(lhs.real()) - Output(rhs.real());
|
||||
Output y_i = Output(lhs.imag()) - Output(rhs.imag());
|
||||
|
||||
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
|
||||
}
|
||||
};
|
||||
|
||||
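A short illustration of the new functors added to functional.h, here computing a squared error between two complex values with the accumulation type chosen explicitly (the helper function is hypothetical):

  #include "cutlass/complex.h"
  #include "cutlass/functional.h"

  CUTLASS_HOST_DEVICE
  float squared_error(cutlass::complex<float> a, cutlass::complex<float> b) {
    cutlass::magnitude_squared_difference<cutlass::complex<float>, float> norm2;
    return norm2(a, b);   // |a - b|^2 accumulated in float
  }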
template <typename T>
|
||||
struct divides {
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
517
include/cutlass/gemm/device/gemm_sparse.h
Normal file
@ -0,0 +1,517 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/sparse_gemm.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_sparse.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may
|
||||
be invoked from host code.
|
||||
|
||||
The contributions of this class are:
|
||||
|
||||
1. At compile time, it maps data types and high-level structural parameters onto
|
||||
specific CUTLASS components.
|
||||
|
||||
2. At runtime, it maps logical arguments to GEMM problems to kernel parameters.
|
||||
|
||||
3. At runtime, it launches kernels on the device.
|
||||
|
||||
The intent is to provide a convenient mechanism for interacting with most plausible GEMM
|
||||
configurations for each supported architecture. Consequently, not all parameters are exposed
|
||||
to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy
|
||||
are selected to tradeoff simplicity of the interface with flexibility. We expect
|
||||
most configurations to be specified at this level. Applications with more exotic requirements
|
||||
may construct their kernels of interest using CUTLASS components at the threadblock, warp,
|
||||
and thread levels of abstraction.
|
||||
|
||||
CUTLASS exposes computations using the functor design pattern in which objects compose some
|
||||
internal state with an overloaded function call operator. This enables decoupling of
|
||||
initialization from execution, possibly reducing overhead during steady state phases of
|
||||
application execution.
|
||||
|
||||
CUTLASS device-level operators expose an Arguments structure encompassing each logical
|
||||
input to the computation. This is distinct from the kernel-level Params structure pattern
|
||||
which contains application-specific precomputed state needed by the device code.
|
||||
|
||||
Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN
|
||||
is as follows:
|
||||
|
||||
//
|
||||
// Instantiate the CUTLASS GEMM operator.
|
||||
//
|
||||
|
||||
cutlass::gemm::device::Gemm<
|
||||
float,
|
||||
cutlass::layout::ColumnMajor,
|
||||
float,
|
||||
cutlass::layout::ColumnMajor,
|
||||
float,
|
||||
cutlass::layout::ColumnMajor
|
||||
> gemm_op;
|
||||
|
||||
//
|
||||
// Launch the GEMM operation on the device
|
||||
//
|
||||
|
||||
cutlass::Status status = gemm_op({
|
||||
{m, n, k}, // GemmCoord problem_size,
|
||||
{A, lda}, // TensorRef<float, layout::ColumnMajor> ref_A,
|
||||
{B, ldb}, // TensorRef<float, layout::ColumnMajor> ref_B,
|
||||
{C, ldc}, // TensorRef<float, layout::ColumnMajor> ref_C,
|
||||
{D, ldd}, // TensorRef<float, layout::ColumnMajor> ref_D,
|
||||
{alpha, beta} // EpilogueOutputOp::Params epilogue_op_params
|
||||
});
|
||||
|
||||
|
||||
A simplified view of the template is listed below.
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC,
|
||||
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages
|
||||
>
|
||||
class Gemm;
|
||||
*/
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_ = ElementC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_ = arch::OpClassSimt,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_ = arch::Sm70,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle_ =
|
||||
typename threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kStages,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentA,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// If true, kernel supports split-K with serial reduction
|
||||
bool SplitKSerial = false,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::Operator,
|
||||
/// Whether Beta is zero or not
|
||||
bool IsBetaZero = false>
|
||||
class SparseGemm {
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = LayoutA_;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using OperatorClass = OperatorClass_;
|
||||
using ArchTag = ArchTag_;
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using EpilogueOutputOp = EpilogueOutputOp_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using Operator = Operator_;
|
||||
static int const kStages = Stages;
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static bool const kIsBetaZero = IsBetaZero;
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Define the kernel
|
||||
using GemmKernel = typename kernel::DefaultSparseGemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator,
|
||||
kIsBetaZero
|
||||
>::GemmKernel;
|
||||
|
||||
using ElementE = typename GemmKernel::ElementE;
|
||||
|
||||
using LayoutE = typename GemmKernel::LayoutE;
|
||||
|
||||
static int const kAlignmentE = 128 / sizeof_bits<ElementE>::value;
|
||||
|
||||
static int const kSparse = GemmKernel::kSparse;
|
||||
static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits;
|
||||
static int const kElementsPerElementE = GemmKernel::kElementsPerElementE;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord problem_size;
|
||||
TensorRef<ElementA const, LayoutA> ref_A;
|
||||
TensorRef<ElementB const, LayoutB> ref_B;
|
||||
TensorRef<ElementC const, LayoutC> ref_C;
|
||||
TensorRef<ElementC, LayoutC> ref_D;
|
||||
TensorRef<ElementE const, LayoutE> ref_E;
|
||||
typename EpilogueOutputOp::Params epilogue;
|
||||
int split_k_slices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmCoord problem_size_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C_,
|
||||
TensorRef<ElementC, LayoutC> ref_D_,
|
||||
TensorRef<ElementE, LayoutE> ref_E_,
|
||||
typename EpilogueOutputOp::Params epilogue_ =
|
||||
typename EpilogueOutputOp::Params(),
|
||||
int split_k_slices = 1
|
||||
):
|
||||
problem_size(problem_size_),
|
||||
ref_A(ref_A_),
|
||||
ref_B(ref_B_),
|
||||
ref_C(ref_C_),
|
||||
ref_D(ref_D_),
|
||||
ref_E(ref_E_),
|
||||
epilogue(epilogue_),
|
||||
split_k_slices(split_k_slices) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the GEMM.
|
||||
SparseGemm() { }
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
Status status = GemmKernel::can_implement(
|
||||
args.problem_size,
|
||||
args.ref_A.non_const_ref(),
|
||||
args.ref_B.non_const_ref(),
|
||||
args.ref_C.non_const_ref(),
|
||||
args.ref_D,
|
||||
args.ref_E.non_const_ref()
|
||||
);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
|
||||
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||||
}
|
||||
|
||||
return bytes;
|
||||
}
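  // Worked example (illustrative, not part of the original source): assuming kSplitKSerial
  // is enabled, a 1024x1024x1024 problem with 128x128 threadblock tiles and
  // split_k_slices == 2 produces a tiled shape of 8x8x2, so the workspace is
  // sizeof(int) * 8 * 8 = 256 bytes of per-tile semaphores.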
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial) {
|
||||
if (args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
size_t bytes = get_workspace_size(args);
|
||||
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
if (args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params{
|
||||
args.problem_size,
|
||||
grid_shape,
|
||||
args.ref_A.non_const_ref(),
|
||||
args.ref_B.non_const_ref(),
|
||||
args.ref_C.non_const_ref(),
|
||||
args.ref_D,
|
||||
args.ref_E.non_const_ref(),
|
||||
args.epilogue,
|
||||
static_cast<int *>(workspace)
|
||||
};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
}
|
||||
|
||||
params_.ref_A.reset(args.ref_A.non_const_ref().data());
|
||||
params_.ref_B.reset(args.ref_B.non_const_ref().data());
|
||||
params_.ref_C.reset(args.ref_C.non_const_ref().data());
|
||||
params_.ref_D.reset(args.ref_D.data());
|
||||
params_.ref_E.reset(args.ref_E.non_const_ref().data());
|
||||
params_.output_op = args.epilogue;
|
||||
params_.semaphore = static_cast<int *>(workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
cudaError_t result;
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
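  // Note (illustrative, not part of the original source): this opt-in is required whenever
  // SharedStorage exceeds the 48 KB default. For example, 128x128x64 half-precision tiles
  // pipelined over 3 stages occupy roughly 3 * (128*64 + 64*128) * 2 bytes ~= 96 KB, which
  // is only reachable after raising cudaFuncAttributeMaxDynamicSharedMemorySize.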
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
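The device-level SparseGemm above follows the same host-side calling convention as the dense GEMM shown earlier. A rough usage sketch is given below; it is not part of this commit, and the element types, layouts, tile defaults, and all host variables (m, n, k, pointers, leading dimensions, alpha, beta) are illustrative assumptions only.

  // Sparse tensor-op GEMM on SM80: FP16 operands, FP32 accumulation.
  using SparseGemmOp = cutlass::gemm::device::SparseGemm<
      cutlass::half_t, cutlass::layout::ColumnMajor,   // A (compressed: k / kSparse columns)
      cutlass::half_t, cutlass::layout::ColumnMajor,   // B
      float,           cutlass::layout::RowMajor,      // C and D
      float,                                           // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80>;

  SparseGemmOp gemm_op;

  typename SparseGemmOp::Arguments args(
      {m, n, k},            // GemmCoord problem_size
      {A_compressed, lda},  // ref_A - compressed A operand
      {B, ldb},             // ref_B
      {C, ldc},             // ref_C
      {D, ldd},             // ref_D
      {E_reordered, lde},   // ref_E - reordered 2:4 sparsity metadata
      {alpha, beta});       // epilogue parameters

  cutlass::Status status = gemm_op(args, /*workspace=*/nullptr, /*stream=*/nullptr);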
|
||||
@ -30,6 +30,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
@ -42,6 +44,8 @@
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@ -121,13 +125,30 @@ public:
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
||||
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
||||
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return GemmKernel::can_implement(args);
|
||||
}
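  // Example (illustrative, not part of the original source): CUDA limits gridDim.y and
  // gridDim.z to 65535, so a batched problem whose batch count is mapped to grid.z and
  // exceeds that limit is rejected here with kErrorInvalidProblem.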
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
@ -151,28 +172,41 @@ public:
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const &args) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" grid_tiled_shape: " << grid_tiled_shape << "\n"
|
||||
<< " result = {" << result << "}");
|
||||
|
||||
return threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10)) {
|
||||
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
@ -182,6 +216,7 @@ public:
|
||||
smem_size);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
}
|
||||
@ -195,6 +230,11 @@ public:
|
||||
0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
@ -216,27 +256,43 @@ public:
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
return std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
if (workspace_bytes) {
|
||||
|
||||
if (!workspace) {
|
||||
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
|
||||
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm) {
|
||||
CUTLASS_TRACE_HOST(" clearing device workspace");
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
@ -262,6 +318,8 @@ public:
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace) {
|
||||
@ -275,6 +333,7 @@ public:
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
@ -302,11 +361,19 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block
|
||||
<< "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
|
||||
@ -96,7 +96,7 @@ struct GemmCoord : public Coord<3, int> {
|
||||
/// Integer-valued index
|
||||
typedef int Index;
|
||||
|
||||
/// Base type is a Coord of rank=4
|
||||
/// Base type is a Coord of rank=3
|
||||
typedef Coord<3, Index> Base;
|
||||
|
||||
/// GEMM M dimension - rows of the output C matrix
|
||||
@ -274,7 +274,7 @@ struct BatchedGemmCoord : public Coord<4, int> {
|
||||
/// GEMM K dimension - inner dimension of the GEMM problem
|
||||
static int const kK = 2;
|
||||
|
||||
/// GEMM K dimension - inner dimension of the GEMM problem
|
||||
/// GEMM Batch dimension - inner dimension of the GEMM problem
|
||||
static int const kBatch = 3;
|
||||
|
||||
//
|
||||
|
||||
187
include/cutlass/gemm/kernel/default_gemm_sparse.h
Normal file
@ -0,0 +1,187 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
*/
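// Illustration (not part of the original header): a column-major D is obtained by solving
// the transposed problem with the operands exchanged,
//
//     D^T = alpha * (B^T * A^T) + beta * C^T,
//
// so the epilogue always streams a row-major output tile.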
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/wmma.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm.h"
|
||||
#include "cutlass/gemm/kernel/sparse_gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_sparse_mma.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
|
||||
#endif //CUTLASS_ARCH_WMMA_ENABLED
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Beta is zero or not
|
||||
bool IsBetaZero = false>
|
||||
struct DefaultSparseGemm;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultSparseGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
|
||||
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
Operator> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape, WarpShape, InstructionShape, Stages,
|
||||
Operator>::ThreadblockMma;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using GemmKernel = kernel::SparseGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -175,7 +175,8 @@ struct Gemm {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
@ -252,7 +253,8 @@ struct Gemm {
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
@ -133,7 +133,8 @@ struct GemmArray {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
@ -207,7 +208,8 @@ struct GemmArray {
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
@ -140,7 +140,8 @@ struct GemmBatched {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
@ -219,7 +220,8 @@ struct GemmBatched {
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
@ -66,7 +66,7 @@ __global__ void GemmPipelined(
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);
|
||||
|
||||
if (grid_tiled_shape.m() <= tb_tile_offset.m() ||
|
||||
grid_tiled_shape.n() <= tb_tile_offset.n()) {
|
||||
@ -131,7 +131,7 @@ __global__ void GemmPipelined(
|
||||
warp_id,
|
||||
lane_id);
|
||||
|
||||
tb_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
@ -419,7 +419,8 @@ public:
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
@ -549,7 +550,8 @@ public:
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
@ -376,7 +376,8 @@ public:
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
|
||||
@ -128,7 +128,8 @@ struct GemmSplitKParallel {
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
@ -205,7 +206,8 @@ struct GemmSplitKParallel {
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
@ -36,6 +36,8 @@
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@ -154,6 +156,7 @@ public:
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
/// Returns arguments for the transposed problem
|
||||
@ -252,6 +255,7 @@ public:
|
||||
batch_stride_D(args.batch_stride_D),
|
||||
semaphore(static_cast<int *>(workspace)) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << problem_size);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -264,9 +268,16 @@ public:
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_D = args.batch_stride_D;
|
||||
|
||||
output_op = args.epilogue;
|
||||
|
||||
semaphore = static_cast<int *>(workspace);
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||
}
|
||||
};
|
||||
|
||||
@ -289,6 +300,8 @@ public:
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size) {
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
@ -297,9 +310,12 @@ public:
|
||||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
|
||||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
@ -314,7 +330,8 @@ public:
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
@ -421,7 +438,8 @@ public:
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
|
||||
392
include/cutlass/gemm/kernel/sparse_gemm.h
Normal file
@ -0,0 +1,392 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
>
|
||||
struct SparseGemm {
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using OutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
static int const kSparse = Mma::kSparse;
|
||||
static int const kMetaSizeInBits = Mma::kMetaSizeInBits;
|
||||
static int const kMaxID2 = Mma::kMaxID2;
|
||||
static int const kElementsPerElementE = Mma::kElementsPerElementE;
|
||||
|
||||
using ElementE = typename Mma::ElementE;
|
||||
using LayoutE = typename Mma::LayoutE;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
typename Mma::IteratorE::Params params_E;
|
||||
typename Mma::IteratorE::TensorRef ref_E;
|
||||
typename OutputOp::Params output_op;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations;
|
||||
int gemm_k_size;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
typename Mma::IteratorA::TensorRef ref_A,
|
||||
typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
||||
typename Mma::IteratorE::TensorRef ref_E,
|
||||
typename OutputOp::Params output_op = typename OutputOp::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
problem_size(problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
params_A(ref_A.layout()),
|
||||
ref_A(ref_A),
|
||||
params_B(ref_B.layout()),
|
||||
ref_B(ref_B),
|
||||
params_C(ref_C.layout()),
|
||||
ref_C(ref_C),
|
||||
params_D(ref_D.layout()),
|
||||
ref_D(ref_D),
|
||||
params_E(ref_E.layout()),
|
||||
ref_E(ref_E),
|
||||
output_op(output_op) {
|
||||
|
||||
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
||||
|
||||
semaphore = workspace;
|
||||
}
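    // Worked example (illustrative, not part of the original source): for k == 1024 and
    // Mma::Shape::kK == 64, total_gemm_k_iterations == 16; with grid_tiled_shape.k() == 2,
    // each k-slice runs 8 iterations and gemm_k_size == 8 * 64 == 512.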
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SparseGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
typename Mma::IteratorA::TensorRef ref_A,
|
||||
typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
||||
typename Mma::IteratorE::TensorRef ref_E) {
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
static int const kAlignmentE = Mma::IteratorE::AccessType::kElements;
|
||||
|
||||
if (!TensorRef_aligned(ref_A, kAlignmentA)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_B, kAlignmentB)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_C, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_D, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_E, kAlignmentE)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) ||
|
||||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
|
||||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) ||
|
||||
(problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) {
|
||||
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
// The K dimension has to be a multiple of the threadblock K because out-of-
|
||||
// bounds metadata would be initialized to 0 by async zfill, but 0 is not
|
||||
// valid metadata.
|
||||
if (problem_size.k() % Mma::Shape::kK) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
// The M dimension has to be a multiple of 32 (sparse float) or 16 (sparse int)
|
||||
// because of the row reordering of operand E
|
||||
static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16;
|
||||
|
||||
if (problem_size.m() % kAlignmentM) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
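  // Illustrative numbers (not part of the original source): with 16-bit operands and
  // 128-bit accesses, kAlignmentA == kAlignmentB == 8; ElementE is 16 bits, so
  // kAlignmentM == 32. A problem size (m, n, k) = (128, 128, 256) with kSparse == 2 and
  // Mma::Shape::kK == 64 passes every check above, whereas k == 200 fails the
  // (problem_size.k() % Mma::Shape::kK) test.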
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_E{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size / kSparse,
|
||||
};
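    // Example (illustrative, not part of the original source): with kSparse == 2 and
    // params.gemm_k_size == 512, the threadblock at k-tile index 1 starts reading the
    // compressed A (and metadata E) at column 256 and operand B at row 512.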
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k = min(
|
||||
params.problem_size.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A, B, and E operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A,
|
||||
params.ref_A.data(),
|
||||
{params.problem_size.m(), problem_size_k / kSparse},
|
||||
thread_idx,
|
||||
tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B,
|
||||
params.ref_B.data(),
|
||||
{problem_size_k, params.problem_size.n()},
|
||||
thread_idx,
|
||||
tb_offset_B);
|
||||
|
||||
typename Mma::IteratorE iterator_E(
|
||||
params.params_E, params.ref_E.data(),
|
||||
{params.problem_size.m(),
|
||||
problem_size_k / kSparse / kElementsPerElementE},
|
||||
thread_idx, tb_offset_E);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
OutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.params_C,
|
||||
params.ref_C.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params.params_D,
|
||||
params.ref_D.data(),
|
||||
params.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
__threadfence();
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -229,6 +229,16 @@ struct Mma<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<ElementC, Shape::kMN>;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename MmaGeneric<
|
||||
Shape,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
Operator>::MmaOp;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
@ -977,6 +977,30 @@ struct Mma<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<ElementC, Shape::kMN>;
|
||||
|
||||
static bool const a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;
|
||||
static bool const b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;
|
||||
static bool const c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;
|
||||
static bool const c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;
|
||||
|
||||
static bool const m_mod2 = !(Shape::kM % 2);
|
||||
static bool const n_mod2 = !(Shape::kN % 2);
|
||||
static bool const k_mod2 = !(Shape::kK % 2);
|
||||
|
||||
// HFMA-based MMA optimizations are of two types:
|
||||
// 1. Inner product
|
||||
// 2. Outer product
|
||||
// The variant is chosen based on LayoutC (for outer-product GEMMs), or on
|
||||
// LayoutA and LayoutB or a 1x1x2K shape (for inner-product GEMMs).
|
||||
// If neither applies, the generic MMA is chosen.
|
||||
static bool const use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
|
||||
static bool const use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
|
||||
static bool const use_optimized = (use_outer_prod || use_inner_prod);
|
||||
|
||||
using ArchMmaOperator = typename platform::conditional< use_optimized,
|
||||
detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,
|
||||
MmaGeneric <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator>
|
||||
>::type;
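  // Examples (illustrative, not part of the original source): Shape = GemmShape<8, 8, 4> with a
  // column-major C selects the outer-product Mma_HFMA2 path (kM is even); a row-major A,
  // column-major B, and Shape = GemmShape<1, 1, 8> selects the inner-product path; any
  // remaining combination falls back to MmaGeneric.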
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -989,30 +1013,8 @@ struct Mma<
|
||||
FragmentB const & B,
|
||||
FragmentC const & C) {
|
||||
|
||||
constexpr bool a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;
|
||||
constexpr bool b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;
|
||||
constexpr bool c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;
|
||||
constexpr bool c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
constexpr bool m_mod2 = !(Shape::kM % 2);
|
||||
constexpr bool n_mod2 = !(Shape::kN % 2);
|
||||
constexpr bool k_mod2 = !(Shape::kK % 2);
|
||||
|
||||
// HFMA based MMA optimizations are of 2 types :
|
||||
// 1. Inner product
|
||||
// 2. Outer product
|
||||
// It is chosen based on LayoutC (for outer product gemm) or
|
||||
// Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms)
|
||||
// If all fails, we choose the generic MMA
|
||||
constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
|
||||
constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
|
||||
constexpr bool use_optimized = (use_outer_prod || use_inner_prod);
|
||||
|
||||
typename platform::conditional< use_optimized,
|
||||
detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,
|
||||
MmaGeneric <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator>
|
||||
>::type mma;
|
||||
|
||||
mma(D, A, B, C);
|
||||
|
||||
}
|
||||
@ -1086,6 +1088,8 @@ struct Mma<
|
||||
using FragmentB = Array<ElementB, Shape::kKN>;
|
||||
using FragmentC = Array<ElementC, Shape::kMN>;
|
||||
|
||||
using ArchMmaOperator = typename TransposeMma::ArchMmaOperator;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
FragmentC & D,
|
||||
|
||||
@ -93,6 +93,19 @@ struct Mma<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<ElementC, Shape::kMN>;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
// Use 1x1x4 IDP4A sequence for bulk of computation
|
||||
using ArchMmaOperator = arch::Mma<
|
||||
gemm::GemmShape<1,1,4>,
|
||||
1,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
arch::OpMultiplyAdd>;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -112,22 +125,11 @@ struct Mma<
|
||||
D = C;
|
||||
|
||||
/// Use 1x1x4 IDP4A sequence for bulk of computation
|
||||
using Mma = arch::Mma<
|
||||
gemm::GemmShape<1,1,4>,
|
||||
1,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
arch::OpMultiplyAdd>;
|
||||
|
||||
Mma mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
// Compute matrix product
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
|
||||
for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < Shape::kN; ++n) {
|
||||
@ -143,8 +145,8 @@ struct Mma<
|
||||
|
||||
mma(
|
||||
tmp,
|
||||
ptr_A[m * Shape::kK / Mma::Shape::kK + k],
|
||||
ptr_B[n * Shape::kK / Mma::Shape::kK + k],
|
||||
ptr_A[m * Shape::kK / ArchMmaOperator::Shape::kK + k],
|
||||
ptr_B[n * Shape::kK / ArchMmaOperator::Shape::kK + k],
|
||||
tmp);
|
||||
|
||||
d.at(mn) = reinterpret_cast<int32_t &>(tmp);
|
||||
@ -206,6 +208,19 @@ struct Mma<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<ElementC, Shape::kMN>;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
/// Use 1x1x4 IDP4A sequence for bulk of computation
|
||||
using ArchMmaOperator = arch::Mma<
|
||||
gemm::GemmShape<1,1,4>,
|
||||
1,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
arch::OpMultiplyAdd>;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -224,25 +239,15 @@ struct Mma<
|
||||
// Copy accumulators
|
||||
D = C;
|
||||
|
||||
/// Use 1x1x4 IDP4A sequence for bulk of computation
|
||||
using Mma = arch::Mma<
|
||||
gemm::GemmShape<1,1,4>,
|
||||
1,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
arch::OpMultiplyAdd>;
|
||||
|
||||
Mma mma;
|
||||
/// Underlying matrix multiply operator
|
||||
ArchMmaOperator mma;
|
||||
|
||||
Array<int8_t, 4> const *ptr_A = reinterpret_cast<Array<int8_t, 4> const *>(&A);
|
||||
Array<int8_t, 4> const *ptr_B = reinterpret_cast<Array<int8_t, 4> const *>(&B);
|
||||
|
||||
// Compute matrix product
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
|
||||
for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < Shape::kN; ++n) {
|
||||
|
||||
@ -36,6 +36,7 @@
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
|
||||
@ -1,197 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
|
||||
layout of the global memory fragments, data types, and internal tile sizes.
|
||||
|
||||
Partial specializations for threadblock::Mma operations targeting TensorOp instructions.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_simt.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: column-major
|
||||
/// B: row-major
|
||||
/// InstructionShape: 1-by-1-by-1
|
||||
/// Operator: SIMT
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Data type of A operand
|
||||
typename ElementA_,
|
||||
/// Data type of B operand
|
||||
typename ElementB_,
|
||||
/// Data type of accumulator
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_>
|
||||
struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
|
||||
layout::ColumnMajor, ElementB_, layout::RowMajor,
|
||||
ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_,
|
||||
> {
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = layout::RowMajor;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using OperatorClass = arch::OpClassSimt;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<
|
||||
Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK
|
||||
>;
|
||||
|
||||
// Divisility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) &&
|
||||
!(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
|
||||
);
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
/// Shared memory layout for A operand
|
||||
using SmemLayoutA = layout::ColumnMajor;
|
||||
|
||||
/// Shared memory layout for B operand
|
||||
using SmemLayoutB = layout::RowMajor;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM, Shape::kK>,
|
||||
kThreads,
|
||||
1
|
||||
>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK>,
|
||||
ElementA,
|
||||
SmemLayoutA,
|
||||
1,
|
||||
IteratorThreadMapA
|
||||
>;
|
||||
|
||||
/// ThreadMap of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kN, Shape::kK>,
|
||||
kThreads,
|
||||
1
|
||||
>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>,
|
||||
ElementB,
|
||||
SmemLayoutB,
|
||||
0,
|
||||
IteratorThreadMapB
|
||||
>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using WarpMma = cutlass::gemm::warp::MmaSimt<
|
||||
WarpShape,
|
||||
ElementA,
|
||||
SmemLayoutA,
|
||||
ElementB,
|
||||
SmemLayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
warp::MmaSimtPolicy<
|
||||
MatrixShape<4, 8>,
|
||||
layout::RowMajorInterleaved<2>,
|
||||
GemmShape<
|
||||
128 / sizeof_bits<ElementA>::value,
|
||||
128 / sizeof_bits<ElementB>::value,
|
||||
1>
|
||||
>
|
||||
>
|
||||
>;
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy = MmaPolicy<
|
||||
WarpMma,
|
||||
MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>,
|
||||
WarpCount::kK
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1119,11 +1119,18 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, float,
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: column-major-interleave32
|
||||
/// B: row-major-interleave32
|
||||
/// A: column-major-interleave
|
||||
/// B: row-major-interleave
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
///
|
||||
/// Column/RowMajorInterleaved<InterleavedK>(m, n) is mapped to Column/RowMajor(m
|
||||
/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators
|
||||
/// can be reused. The shared store iterator is the same as the crosswise shared
|
||||
/// store iterator. So, the only thing we need to do is to swap the coordinates
|
||||
/// (contiguous <=> strided) used by the global iterator and the shared store
|
||||
/// iterator.
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
|
||||
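A short numeric sketch of the coordinate remapping described in the comment above (a reading of the comment, not a CUTLASS API): a ColumnMajorInterleaved<32> operand of extent (M, N) is iterated as a plain ColumnMajor operand of extent (M * 32, N / 32), which is why the existing global iterators can be reused unchanged.

// Hypothetical helper illustrating the extent remapping from the comment above.
template <int InterleavedK>
void remapped_extent(int m, int n, int &m_mapped, int &n_mapped) {
  m_mapped = m * InterleavedK;   // contiguous extent grows by the interleave factor
  n_mapped = n / InterleavedK;   // strided extent shrinks by the same factor
}
// e.g. remapped_extent<32>(128, 256, ...) yields (4096, 8).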
@ -1362,6 +1362,13 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
///
|
||||
/// Column/RowMajorInterleaved<InterleavedK>(m, n) is mapped to Column/RowMajor(m
|
||||
/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators
|
||||
/// can be reused. The shared store iterator is the same as the crosswise shared
|
||||
/// store iterator. So, the only thing we need to do is to swap the coordinates
|
||||
/// (contiguous <=> strided) used by the global iterator and the shared store
|
||||
/// iterator.
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
@ -1608,7 +1615,7 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Transpose the ThreadMap of iterator A
|
||||
/// Transpose the ThreadMap of iterator B
|
||||
using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapB>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
@ -1916,7 +1923,7 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Transpose the ThreadMap of iterator A
|
||||
/// Transpose the ThreadMap of iterator B
|
||||
using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapB>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
|
||||
828
include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h
Normal file
@ -0,0 +1,828 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Defines basic properties needed by CTA-level GEMMs assuming
|
||||
expectations about data layout of the global memory fragments, data types,
|
||||
and internal tile sizes.
|
||||
|
||||
Partial specializations for threadblock::Mma operations targeting sparse
|
||||
TensorOp instructions.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
|
||||
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_simt_policy.h"
|
||||
#include "cutlass/gemm/warp/mma_simt.h"
|
||||
#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma_core.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/transform/pitch_linear_thread_map.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h"
|
||||
#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h"
|
||||
#include "cutlass/gemm/threadblock/mma_sparse_multistage.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template defining default matrix multiply operators inferred from threadblock tile size,
|
||||
/// global memory data layout, and target math instruction.
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator
|
||||
typename Shape,
|
||||
/// Shape of warp-level matrix multiply operator
|
||||
typename WarpShape,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Element data type of A operand
|
||||
typename ElementA,
|
||||
/// Layout of operand A
|
||||
typename LayoutA,
|
||||
/// Element data type of B operand
|
||||
typename ElementB,
|
||||
/// Layout of operand B
|
||||
typename LayoutB,
|
||||
/// Data type of accumulator
|
||||
typename ElementC,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC,
|
||||
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
|
||||
typename OperatorClass,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Operation performed by MMA
|
||||
typename Operator = typename platform::conditional<
|
||||
(platform::is_same<OperatorClass,
|
||||
cutlass::arch::OpClassTensorOp>::value) &&
|
||||
(platform::is_same<ElementA, int8_t>::value ||
|
||||
platform::is_same<ElementA, int4b_t>::value ||
|
||||
platform::is_same<ElementA, uint8_t>::value ||
|
||||
platform::is_same<ElementA, uint4b_t>::value),
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::arch::OpMultiplyAdd>::type,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false
|
||||
/// Cache operation of operand A
|
||||
, cutlass::arch::CacheOperation::Kind CacheOpA =
|
||||
cutlass::arch::CacheOperation::Global,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB =
|
||||
cutlass::arch::CacheOperation::Global
|
||||
>
|
||||
struct DefaultSparseMmaCore;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: column-major
|
||||
/// B: row-major
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A operand
|
||||
typename ElementA_,
|
||||
/// Data type of B operand
|
||||
typename ElementB_,
|
||||
/// Data type of accumulator
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Operation performed by MMA
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
layout::ColumnMajor, ElementB_, layout::RowMajor,
|
||||
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
|
||||
Operator_, false, CacheOpA, CacheOpB> {
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = layout::RowMajor;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
static int const kSparse = 2;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisibility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of a threadblock-scoped access
|
||||
static int const kAccessSizeInBits = 128;
|
||||
|
||||
/// Default Operator
|
||||
using Operator = Operator_;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous<
|
||||
sizeof_bits<ElementA>::value, int(128 / sizeof(ElementA))>;
|
||||
|
||||
// Shared memory layout
|
||||
using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous<
|
||||
sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM, Shape::kK / kSparse>, kThreads,
|
||||
layout::PitchLinearShape<8, 4>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 1,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// ThreadMap of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kN, Shape::kK>, kThreads,
|
||||
layout::PitchLinearShape<8, 4>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
|
||||
IteratorThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
|
||||
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
|
||||
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
|
||||
|
||||
/// Cache operation of operand E
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
|
||||
cutlass::arch::CacheOperation::Global;
|
||||
|
||||
static int const kInterleavedE = MmaTensorOp::kInterleaved;
|
||||
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
|
||||
static int const kMaxID2 = MmaTensorOp::kMaxID2;
|
||||
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
|
||||
|
||||
using ElementE = typename MmaTensorOp::ElementE;
|
||||
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
|
||||
|
||||
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
|
||||
using SmemLayoutE = typename MmaTensorOp::LayoutE;
|
||||
|
||||
/// ThreadMap of iterator E
|
||||
static int const kElementsPerAccessE =
|
||||
kAccessSizeInBits / sizeof_bits<ElementE>::value;
|
||||
|
||||
/// E is tiny. Not all warps are needed.
|
||||
static int const kThreadsE =
|
||||
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
|
||||
kThreads)
|
||||
? kThreads
|
||||
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
|
||||
|
||||
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE /
|
||||
kInterleavedE>,
|
||||
kThreadsE, kElementsPerAccessE>;
|
||||
|
||||
/// Shared memory iterator to E operand
|
||||
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
|
||||
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy =
|
||||
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>, WarpCount::kK>;
|
||||
};
|
||||
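The specialization above stores only the nonzero half of A in shared memory (kSparse = 2, so the A tile spans Shape::kK / 2 columns) and adds a third operand E carrying the 2:4 sparsity metadata. Because E is small, kThreadsE caps the number of threads used to stage it: it is the number of 128-bit accesses E requires, clamped to the threadblock's thread count. A sketch of that arithmetic, with all names and inputs illustrative rather than fixed CUTLASS values:

// Sketch of the kThreadsE clamp above: E needs
//   tile_m * (tile_k / sparse) / elems_per_element_e / (128 / bits_per_element_e)
// 128-bit accesses, and that count is capped at the total thread count.
inline int threads_for_e(int tile_m, int tile_k, int sparse,
                         int elems_per_element_e, int bits_per_element_e,
                         int total_threads) {
  int accesses = tile_m * tile_k / sparse / elems_per_element_e
                 / (128 / bits_per_element_e);
  return accesses > total_threads ? total_threads : accesses;
}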
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: row-major
|
||||
/// B: column-major
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A operand
|
||||
typename ElementA_,
|
||||
/// Data type of B operand
|
||||
typename ElementB_,
|
||||
/// Data type of accumulator
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Operation performed by MMA
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
layout::RowMajor, ElementB_, layout::ColumnMajor,
|
||||
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
|
||||
Operator_, false, CacheOpA, CacheOpB> {
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::RowMajor;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = layout::ColumnMajor;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
static int const kSparse = 2;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisibility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of a threadblock-scoped access
|
||||
static int const kAccessSizeInBits = 128;
|
||||
|
||||
/// Default Operator
|
||||
using Operator = Operator_;
|
||||
|
||||
// Warp thread arrangement
|
||||
static int const kWarpThreadArrangementContiguousA =
|
||||
Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
|
||||
|
||||
static int const kWarpThreadArrangementStridedA =
|
||||
kWarpSize / kWarpThreadArrangementContiguousA;
|
||||
|
||||
// crosswise cannot be larger than 1024 bit.
|
||||
static int const kCrosswiseB =
|
||||
(Shape::kK > (1024 / sizeof_bits<ElementB>::value))
|
||||
? (1024 / sizeof_bits<ElementB>::value)
|
||||
: Shape::kK;
|
||||
|
||||
static int const kWarpThreadArrangementContiguousB =
|
||||
kCrosswiseB / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
|
||||
|
||||
static int const kWarpThreadArrangementStridedB =
|
||||
kWarpSize / kWarpThreadArrangementContiguousB;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
|
||||
sizeof_bits<ElementA>::value, Shape::kK / kSparse>;
|
||||
|
||||
// Shared memory layout
|
||||
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
|
||||
sizeof_bits<ElementB>::value, kCrosswiseB>;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK / kSparse, Shape::kM>, kThreads,
|
||||
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
|
||||
kWarpThreadArrangementStridedA>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 0,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// ThreadMap of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
|
||||
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
|
||||
kWarpThreadArrangementStridedB>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
|
||||
IteratorThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
|
||||
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
|
||||
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
|
||||
|
||||
/// Cache operation of operand E
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
|
||||
cutlass::arch::CacheOperation::Global;
|
||||
|
||||
static int const kInterleavedE = MmaTensorOp::kInterleaved;
|
||||
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
|
||||
static int const kMaxID2 = MmaTensorOp::kMaxID2;
|
||||
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
|
||||
|
||||
using ElementE = typename MmaTensorOp::ElementE;
|
||||
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
|
||||
|
||||
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
|
||||
using SmemLayoutE = typename MmaTensorOp::LayoutE;
|
||||
|
||||
/// ThreadMap of iterator E
|
||||
static int const kElementsPerAccessE =
|
||||
kAccessSizeInBits / sizeof_bits<ElementE>::value;
|
||||
|
||||
/// E is tiny. Not all warps are needed.
|
||||
static int const kThreadsE =
|
||||
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
|
||||
kThreads)
|
||||
? kThreads
|
||||
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
|
||||
|
||||
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE /
|
||||
kInterleavedE>,
|
||||
kThreadsE, kElementsPerAccessE>;
|
||||
|
||||
|
||||
/// Shared memory iterator to E operand
|
||||
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
|
||||
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy =
|
||||
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>, WarpCount::kK>;
|
||||
};
|
||||
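In the row-major A / column-major B specialization above, both operands are staged crosswise in shared memory, and the crosswise extent of B is capped at 1024 bits; the warp's contiguous thread count then comes out as the number of 128-bit accesses that fit in that crosswise extent. Illustrative arithmetic only, mirroring the kCrosswiseB and kWarpThreadArrangement* expressions:

inline int crosswise_b(int tile_k, int bits_per_element) {
  int max_elements = 1024 / bits_per_element;   // crosswise cannot exceed 1024 bits
  return tile_k > max_elements ? max_elements : tile_k;
}
inline int warp_contiguous_threads(int crosswise, int bits_per_element) {
  return crosswise / (128 / bits_per_element);  // 128-bit accesses per crosswise line
}
// e.g. with 16-bit elements and tile_k = 128: crosswise_b = 64, 8 contiguous
// threads per warp, and 32 / 8 = 4 threads in the strided direction.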
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: column-major
|
||||
/// B: column-major
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A operand
|
||||
typename ElementA_,
|
||||
/// Data type of B operand
|
||||
typename ElementB_,
|
||||
/// Data type of accumulator
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Operation performed by MMA
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
layout::ColumnMajor, ElementB_, layout::ColumnMajor,
|
||||
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
|
||||
Operator_, false, CacheOpA, CacheOpB> {
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using ElementA = ElementA_;
|
||||
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = layout::ColumnMajor;
|
||||
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
static int const kSparse = 2;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisibility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of a threadblock-scoped access
|
||||
static int const kAccessSizeInBits = 128;
|
||||
|
||||
/// Default Operator
|
||||
using Operator = Operator_;
|
||||
|
||||
// Warp thread arrangement
|
||||
// crosswise cannot be larger than 1024 bit.
|
||||
static int const kCrosswiseB =
|
||||
(Shape::kK > (1024 / sizeof_bits<ElementB>::value))
|
||||
? (1024 / sizeof_bits<ElementB>::value)
|
||||
: Shape::kK;
|
||||
|
||||
static int const kWarpThreadArrangementContiguousB =
|
||||
kCrosswiseB / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
|
||||
|
||||
static int const kWarpThreadArrangementStridedB =
|
||||
kWarpSize / kWarpThreadArrangementContiguousB;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous<
|
||||
sizeof_bits<ElementA>::value, int(128 / sizeof(ElementA))>;
|
||||
|
||||
// Shared memory layout
|
||||
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
|
||||
sizeof_bits<ElementB>::value, kCrosswiseB>;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM, Shape::kK / kSparse>, kThreads,
|
||||
layout::PitchLinearShape<8, 4>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 1,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// ThreadMap of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
|
||||
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
|
||||
kWarpThreadArrangementStridedB>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
|
||||
IteratorThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
|
||||
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
|
||||
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
|
||||
|
||||
/// Cache operation of operand E
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
|
||||
cutlass::arch::CacheOperation::Global;
|
||||
|
||||
static int const kInterleavedE = MmaTensorOp::kInterleaved;
|
||||
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
|
||||
static int const kMaxID2 = MmaTensorOp::kMaxID2;
|
||||
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
|
||||
|
||||
using ElementE = typename MmaTensorOp::ElementE;
|
||||
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
|
||||
|
||||
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
|
||||
using SmemLayoutE = typename MmaTensorOp::LayoutE;
|
||||
|
||||
/// ThreadMap of iterator E
|
||||
static int const kElementsPerAccessE =
|
||||
kAccessSizeInBits / sizeof_bits<ElementE>::value;
|
||||
|
||||
/// E is tiny. Not all warps are needed.
|
||||
static int const kThreadsE =
|
||||
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
|
||||
kThreads)
|
||||
? kThreads
|
||||
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
|
||||
|
||||
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE /
|
||||
kInterleavedE>,
|
||||
kThreadsE, kElementsPerAccessE>;
|
||||
|
||||
/// Shared memory iterator to E operand
|
||||
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
|
||||
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy =
|
||||
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>, WarpCount::kK>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: row-major
|
||||
/// B: row-major
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A operand
|
||||
typename ElementA_,
|
||||
/// Data type of B operand
|
||||
typename ElementB_,
|
||||
/// Data type of accumulator
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Operation performed by MMA
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultSparseMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
|
||||
LayoutC_, arch::OpClassTensorOp, Stages, Operator_,
|
||||
false, CacheOpA, CacheOpB> {
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::RowMajor;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = layout::RowMajor;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
static int const kSparse = 2;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisibility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of a threadblock-scoped access
|
||||
static int const kAccessSizeInBits = 128;
|
||||
|
||||
/// Default Operator
|
||||
using Operator = Operator_;
|
||||
|
||||
// Warp thread arrangement
|
||||
static int const kWarpThreadArrangementContiguousA =
|
||||
Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
|
||||
|
||||
static int const kWarpThreadArrangementStridedA =
|
||||
kWarpSize / kWarpThreadArrangementContiguousA;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
|
||||
sizeof_bits<ElementA>::value, Shape::kK / kSparse>;
|
||||
|
||||
// Shared memory layout
|
||||
using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous<
|
||||
sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK / kSparse, Shape::kM>, kThreads,
|
||||
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
|
||||
kWarpThreadArrangementStridedA>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK / kSparse>, ElementA, SmemLayoutA, 0,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// ThreadMap of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kN, Shape::kK>, kThreads,
|
||||
layout::PitchLinearShape<8, 4>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
|
||||
IteratorThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
|
||||
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
|
||||
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
|
||||
|
||||
/// Cache operation of operand E
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpE =
|
||||
cutlass::arch::CacheOperation::Global;
|
||||
|
||||
static int const kInterleavedE = MmaTensorOp::kInterleaved;
|
||||
static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits;
|
||||
static int const kMaxID2 = MmaTensorOp::kMaxID2;
|
||||
static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE;
|
||||
|
||||
using ElementE = typename MmaTensorOp::ElementE;
|
||||
using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved<kInterleavedE>;
|
||||
|
||||
// Shared memory layout. Interleaved layout is mapped to PitchLinear layout.
|
||||
using SmemLayoutE = typename MmaTensorOp::LayoutE;
|
||||
|
||||
/// ThreadMap of iterator E
|
||||
static int const kElementsPerAccessE =
|
||||
kAccessSizeInBits / sizeof_bits<ElementE>::value;
|
||||
|
||||
/// E is tiny. Not all warps are needed.
|
||||
static int const kThreadsE =
|
||||
(Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value) >
|
||||
kThreads)
|
||||
? kThreads
|
||||
: (Shape::kM * Shape::kK / kSparse / kElementsPerElementE /
|
||||
(kAccessSizeInBits / sizeof_bits<ElementE>::value));
|
||||
|
||||
using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE /
|
||||
kInterleavedE>,
|
||||
kThreadsE, kElementsPerAccessE>;
|
||||
|
||||
/// Shared memory iterator to E operand
|
||||
using SmemIteratorE = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM * kInterleavedE,
|
||||
Shape::kK / kSparse / kElementsPerElementE / kInterleavedE>,
|
||||
ElementE, SmemLayoutE, 0, IteratorThreadMapE>;
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy =
|
||||
SparseMmaPolicy<MmaTensorOp, MatrixShape<0, 0>, MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>, WarpCount::kK>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
190
include/cutlass/gemm/threadblock/default_sparse_mma.h
Normal file
@ -0,0 +1,190 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/wmma.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h"
|
||||
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_wmma.h"
|
||||
#endif //CUTLASS_ARCH_WMMA_ENABLED
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false
|
||||
>
|
||||
struct DefaultSparseMma;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp)
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultSparseMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, Stages, Operator, false> {
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
static int const kSparse = MmaCore::kSparse;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK / kSparse>,
|
||||
ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>;
|
||||
|
||||
// Define iterators over tiles from the E operand
|
||||
using ElementE = typename MmaCore::ElementE;
|
||||
using LayoutE = typename MmaCore::GmemLayoutE;
|
||||
using ThreadMapE = typename MmaCore::IteratorThreadMapE;
|
||||
using AccessTypeE =
|
||||
cutlass::Array<ElementE, 128 / sizeof_bits<ElementE>::value>;
|
||||
using IteratorE =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM,
|
||||
ThreadblockShape::kK / kSparse /
|
||||
MmaCore::kElementsPerElementE>,
|
||||
ElementE, LayoutE, 1, ThreadMapE, AccessTypeE>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::SparseMmaMultistage<
|
||||
typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
|
||||
MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor,
|
||||
IteratorE, typename MmaCore::SmemIteratorE, MmaCore::kCacheOpE,
|
||||
typename MmaCore::MmaPolicy, Stages>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
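The DefaultSparseMma specialization in the file above composes the sparse MmaCore, the predicated global iterators for A, B, and the metadata operand E, and the SparseMmaMultistage threadblock loop; its cache-operation kinds fall back to CacheOperation::Always whenever the per-thread access is narrower than 128 bits. A hedged instantiation sketch follows: the element types, tile shapes, and alignments are chosen for illustration and are not taken from this commit.

using SparseMma = cutlass::gemm::threadblock::DefaultSparseMma<
    cutlass::half_t, cutlass::layout::RowMajor, 8,      // A: element, layout, alignment
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B: element, layout, alignment
    float, cutlass::layout::RowMajor,                   // accumulator and C/D layout
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,             // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,               // warp tile
    cutlass::gemm::GemmShape<16, 8, 32>,                // sparse tensor-op instruction shape
    3,                                                  // pipeline stages
    cutlass::arch::OpMultiplyAdd>;

// The threadblock-scoped loop is then SparseMma::ThreadblockMma, and the
// global-memory iterators are SparseMma::IteratorA / IteratorB / IteratorE.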
259
include/cutlass/gemm/threadblock/mma_sparse_base.h
Normal file
@ -0,0 +1,259 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Policy object describing MmaTensorOp
|
||||
template <
|
||||
/// Warp-level GEMM operator (concept: gemm::warp::Mma)
|
||||
typename Operator_,
|
||||
/// Padding used for A operand in shared memory (concept: MatrixShape)
|
||||
typename SmemPaddingA_,
|
||||
/// Padding used for B operand in shared memory (concept: MatrixShape)
|
||||
typename SmemPaddingB_,
|
||||
/// Padding used for E operand in shared memory (concept: MatrixShape)
|
||||
typename SmemPaddingE_,
|
||||
/// Number of partitions of K dimension of GEMM
|
||||
int PartitionsK = 1>
|
||||
struct SparseMmaPolicy {
|
||||
/// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
|
||||
using Operator = Operator_;
|
||||
|
||||
/// Padding used for A operand in shared memory
|
||||
using SmemPaddingA = SmemPaddingA_;
|
||||
|
||||
/// Padding used for B operand in shared memory
|
||||
using SmemPaddingB = SmemPaddingB_;
|
||||
|
||||
/// Padding used for E operand in shared memory
|
||||
using SmemPaddingE = SmemPaddingE_;
|
||||
|
||||
/// Number of partitions of K dimension
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Base structure for a threadblock-scoped sparse matrix multiply targeting
/// Sparse Tensor Cores.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class SparseMmaBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
static int const kSparse = Operator::kSparse;
|
||||
|
||||
static int const kElementsPerElementE = Operator::kElementsPerElementE;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
/// Tensor reference to the E operand
|
||||
using TensorRefE = TensorRef<typename Operator::ElementE, typename Operator::LayoutE>;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK / kSparse * kStages +
|
||||
Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB =
|
||||
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
/// Shape of the E matrix operand in shared memory
|
||||
using ShapeE =
|
||||
MatrixShape<Shape::kM * 2 + Policy::SmemPaddingE::kRow,
|
||||
Shape::kK / kSparse / kElementsPerElementE / 2 * kStages +
|
||||
Policy::SmemPaddingE::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer for E operand
|
||||
AlignedBuffer<typename Operator::ElementE, ShapeE::kCount> operand_E;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the E matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutE LayoutE() {
|
||||
return Operator::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the E operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefE operand_E_ref() {
|
||||
return TensorRefE{operand_E.data(), LayoutE()};
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of E operand from shared memory
|
||||
typename Operator::IteratorE warp_tile_iterator_E_;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx),
|
||||
warp_tile_iterator_E_(shared_storage.operand_E_ref(), lane_idx) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
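For reference, a minimal host-side sketch of the SharedStorage shape arithmetic defined by SparseMmaBase above. The 128x128x64 threadblock tile, 3 stages, half-precision operands (hence kSparse == 2 and kElementsPerElementE == 8), and zero shared-memory padding are assumptions for illustration, not values taken from this commit.

#include <cstdio>

int main() {
  // Hypothetical threadblock tile and pipeline depth (not taken from the diff).
  int const kM = 128, kN = 128, kK = 64, kStages = 3;
  int const kSparse = 2;               // A operand is 2:4 structured sparse
  int const kElementsPerElementE = 8;  // 128 / sizeof_bits<half_t> metadata packing

  // ShapeA: <kM, kK / kSparse * kStages>  (A is stored compressed)
  int const shapeA_rows = kM;
  int const shapeA_cols = kK / kSparse * kStages;

  // ShapeB: <kK * kStages, kN>            (B is dense)
  int const shapeB_rows = kK * kStages;
  int const shapeB_cols = kN;

  // ShapeE: <kM * 2, kK / kSparse / kElementsPerElementE / 2 * kStages>
  int const shapeE_rows = kM * 2;
  int const shapeE_cols = kK / kSparse / kElementsPerElementE / 2 * kStages;

  std::printf("A: %d x %d   B: %d x %d   E: %d x %d\n",
              shapeA_rows, shapeA_cols, shapeB_rows, shapeB_cols,
              shapeE_rows, shapeE_cols);
  return 0;
}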
667
include/cutlass/gemm/threadblock/mma_sparse_multistage.h
Normal file
@ -0,0 +1,667 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/mma_sparse_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute a threadblock-scoped sparse matrix product targeting
/// Sparse Tensor Cores, using multistage cp.async software pipelining.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Iterates over tiles of E operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorE_,
|
||||
/// Iterates over tiles of E operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorE_,
|
||||
/// Cache operation for operand E
|
||||
cutlass::arch::CacheOperation::Kind CacheOpE,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class SparseMmaMultistage :
|
||||
public SparseMmaBase<Shape_, Policy_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = SparseMmaBase<Shape_, Policy_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Iterates over tiles of E operand in global memory
|
||||
using IteratorE = IteratorE_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorE = SmemIteratorE_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpE = CacheOpE;
|
||||
|
||||
static int const kSparse = Policy::Operator::kSparse;
|
||||
static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits;
|
||||
static int const kMaxID2 = Policy::Operator::kMaxID2;
|
||||
static int const kElementsPerElementE =
|
||||
Policy::Operator::kElementsPerElementE;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// ElementE
|
||||
using ElementE = typename IteratorE::Element;
|
||||
|
||||
/// LayoutE
|
||||
using LayoutE = typename IteratorE::Layout;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of async copies to load one stage of operand A
|
||||
static int const TBLDGSTSIterationsA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of async copies to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of async copies to load one stage of operand E
|
||||
static int const TBLDGSTSIterationsE =
|
||||
IteratorE::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of async copies to load one group of operand A
|
||||
static int const kAccessesPerGroupA =
|
||||
(TBLDGSTSIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of async copies to load one group of operand B
|
||||
static int const kAccessesPerGroupB =
|
||||
(TBLDGSTSIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of async copies to load one group of operand E
|
||||
static int const kAccessesPerGroupE =
|
||||
(TBLDGSTSIterationsE + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// The E operand is tiny. Most of the time, not all warps are needed to
/// load it from global memory.
|
||||
static int const kValidWarps = IteratorE::ThreadMap::kThreads / 32;
|
||||
|
||||
/// The B operand is twice as large as A, which causes very high register
/// pressure. We have to sacrifice the double buffer when the warp tile is big.
|
||||
static int const kBBufferSize =
|
||||
((sizeof(typename Operator::ElementC) == 4) &&
|
||||
((platform::is_same<typename Operator::Policy::Operator::ElementA,
|
||||
typename Operator::ElementA>::value &&
|
||||
platform::is_same<typename Operator::Policy::Operator::ElementB,
|
||||
typename Operator::ElementB>::value)) &&
|
||||
(Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64))
|
||||
? 1
|
||||
: 2;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
using WarpFragmentE = typename Operator::FragmentE;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of E operand to shared memory
|
||||
SmemIteratorE smem_iterator_E_;
|
||||
|
||||
/// Whether this warp participates in loading the E operand
|
||||
bool is_warp_valid_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
||||
smem_iterator_E_(shared_storage.operand_E_ref(), thread_idx)
|
||||
{
|
||||
is_warp_valid_ = warp_idx < Detail::kValidWarps;
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
this->warp_tile_iterator_E_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B,
|
||||
IteratorE &iterator_E, int group_start_A = 0,
|
||||
int group_start_B = 0, int group_start_E = 0) {
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// async copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
if (group_start_A + j < Detail::TBLDGSTSIterationsA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// async copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::TBLDGSTSIterationsB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_E.set_iteration_index(group_start_E);
|
||||
this->smem_iterator_E_.set_iteration_index(group_start_E);
|
||||
|
||||
// async copy for operand E
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupE; ++j) {
|
||||
if (group_start_E + j < Detail::TBLDGSTSIterationsE) {
|
||||
typename IteratorE::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorE::AccessType *>(
|
||||
this->smem_iterator_E_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorE::Element>::value *
|
||||
IteratorE::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
auto gmem_ptr = iterator_E.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpE>(
|
||||
dst_ptr, gmem_ptr, iterator_E.valid() && is_warp_valid_);
|
||||
|
||||
++iterator_E;
|
||||
++this->smem_iterator_E_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over E operand in global memory
|
||||
IteratorE iterator_E,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// async copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// async copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
iterator_E.set_iteration_index(0);
|
||||
this->smem_iterator_E_.set_iteration_index(0);
|
||||
|
||||
// async copy for operand E
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsE; ++j) {
|
||||
typename IteratorE::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorE::AccessType *>(
|
||||
this->smem_iterator_E_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorE::Element>::value *
|
||||
IteratorE::ThreadMap::kElementsPerAccess / 8;
|
||||
if (is_warp_valid_)
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpE>(
|
||||
dst_ptr, iterator_E.get(), iterator_E.valid());
|
||||
|
||||
++iterator_E;
|
||||
|
||||
++this->smem_iterator_E_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
iterator_E.add_tile_offset({0, 1});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
this->smem_iterator_E_.add_tile_offset({0, 1});
|
||||
|
||||
// LDGDEPBAR - completes a stage
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B[Detail::kBBufferSize];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B[Detail::kBBufferSize];
|
||||
WarpFragmentE warp_frag_E[2];
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_E_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
|
||||
this->warp_tile_iterator_E_.load(warp_frag_E[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
++this->warp_tile_iterator_E_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to the k offset if
// this is the last group.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_E_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_E_.load(warp_frag_E[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_E_;
|
||||
|
||||
if (Detail::kBBufferSize == 2) {
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(
|
||||
warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B[warp_mma_k % Detail::kBBufferSize]);
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], accum,
|
||||
warp_frag_E[warp_mma_k % 2]
|
||||
);
|
||||
|
||||
if (Detail::kBBufferSize == 1) {
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
}
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E;
|
||||
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
group_start_iteration_E = warp_mma_k * Detail::kAccessesPerGroupE;
|
||||
|
||||
copy_tiles_and_advance(
|
||||
iterator_A, iterator_B, iterator_E, group_start_iteration_A,
|
||||
group_start_iteration_B, group_start_iteration_E);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E;
|
||||
group_start_iteration_A =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
group_start_iteration_E =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupE;
|
||||
|
||||
copy_tiles_and_advance(
|
||||
iterator_A, iterator_B, iterator_E, group_start_iteration_A,
|
||||
group_start_iteration_B, group_start_iteration_E);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
iterator_E.add_tile_offset({0, 1});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
this->smem_iterator_E_.add_tile_offset({0, 1});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_E_.add_tile_offset({0, -Base::kStages});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
this->warp_tile_iterator_E_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
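The mainloop of SparseMmaMultistage above prefetches kStages - 1 stages in a prologue and then advances circular write/read stage indices, rewinding iterators with negative tile offsets when an index wraps. A rough standalone simulation of only that bookkeeping, assuming a hypothetical 3-stage pipeline and 8 k-iterations:

#include <cstdio>

int main() {
  int const kStages = 3;       // hypothetical pipeline depth
  int gemm_k_iterations = 8;   // hypothetical number of k tiles in the mainloop

  int smem_write_stage_idx = kStages - 1;  // prologue already filled stages 0..kStages-2
  int smem_read_stage_idx = 0;

  // Loop until the in-flight stages drain, mirroring the mainloop bound above.
  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations) {
    std::printf("compute from stage %d, prefetch into stage %d\n",
                smem_read_stage_idx, smem_write_stage_idx);

    // Advance the write stage, wrapping back to the start of the circular buffer.
    smem_write_stage_idx =
        (smem_write_stage_idx == kStages - 1) ? 0 : smem_write_stage_idx + 1;
    // Advance the read stage the same way.
    smem_read_stage_idx =
        (smem_read_stage_idx == kStages - 1) ? 0 : smem_read_stage_idx + 1;
  }
  return 0;
}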
@ -123,16 +123,22 @@ struct GemmIdentityThreadblockSwizzle {
|
||||
/// Computes CUDA grid dimensions given a size in units of logical tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
dim3 get_grid_shape(GemmCoord tiled_shape) const {
|
||||
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
|
||||
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
|
||||
|
||||
return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
|
||||
}
|
||||
|
||||
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
|
||||
CUTLASS_DEVICE
|
||||
GemmCoord get_tile_offset() const {
|
||||
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
|
||||
|
||||
int block_idx_x = RematerializeBlockIdxX();
|
||||
int block_idx_y = RematerializeBlockIdxY();
|
||||
|
||||
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
|
||||
return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()};
|
||||
|
||||
return GemmCoord{
|
||||
(block_idx_x / kTile),
|
||||
(block_idx_y * kTile) + (block_idx_x % kTile),
|
||||
@ -170,7 +176,7 @@ struct GemmHorizontalThreadblockSwizzle {
|
||||
|
||||
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
|
||||
CUTLASS_DEVICE
|
||||
GemmCoord get_tile_offset() const {
|
||||
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
|
||||
return GemmCoord{
|
||||
RematerializeBlockIdxY(),
|
||||
RematerializeBlockIdxX(),
|
||||
@ -205,7 +211,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
|
||||
|
||||
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
|
||||
CUTLASS_DEVICE
|
||||
GemmCoord get_tile_offset() const {
|
||||
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
|
||||
return GemmCoord{
|
||||
RematerializeBlockIdxX(),
|
||||
RematerializeBlockIdxY(),
|
||||
@ -244,17 +250,23 @@ struct GemmSplitKIdentityThreadblockSwizzle {
|
||||
/// Computes CUDA grid dimensions given a size in units of logical tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
dim3 get_grid_shape(GemmCoord tiled_shape) const {
|
||||
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
|
||||
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
|
||||
|
||||
return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
|
||||
}
|
||||
|
||||
|
||||
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
|
||||
CUTLASS_DEVICE
|
||||
GemmCoord get_tile_offset() const {
|
||||
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
|
||||
|
||||
int block_idx_x = RematerializeBlockIdxX();
|
||||
int block_idx_y = RematerializeBlockIdxY();
|
||||
|
||||
if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile))
|
||||
return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()};
|
||||
|
||||
return GemmCoord{
|
||||
(block_idx_x / kTile),
|
||||
(block_idx_y * kTile) + (block_idx_x % kTile),
|
||||
@ -290,7 +302,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {
|
||||
|
||||
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
|
||||
CUTLASS_DEVICE
|
||||
GemmCoord get_tile_offset() const {
|
||||
GemmCoord get_tile_offset(GemmCoord tiled_shape) const {
|
||||
return GemmCoord{
|
||||
RematerializeBlockIdxY(),
|
||||
RematerializeBlockIdxX(),
|
||||
|
||||
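A small standalone sketch of the swizzled block-to-tile mapping in the hunks above: when both logical grid dimensions are at least kTile, the grid is launched as (m * kTile, ceil(n / kTile), k) and each block maps back to tile (bx / kTile, by * kTile + bx % kTile). The kTile value of 2 and the 4x4x1 logical tile grid below are assumptions for illustration:

#include <cstdio>

int main() {
  int const kTile = 2;
  int const tiled_m = 4, tiled_n = 4, tiled_k = 1;  // logical tiles, both >= kTile

  // Grid shape as computed by get_grid_shape() above.
  int const grid_x = tiled_m * kTile;
  int const grid_y = (tiled_n + kTile - 1) / kTile;
  int const grid_z = tiled_k;

  // Enumerate blocks and print the tile each maps to, as in get_tile_offset().
  for (int bz = 0; bz < grid_z; ++bz)
    for (int by = 0; by < grid_y; ++by)
      for (int bx = 0; bx < grid_x; ++bx) {
        int tile_m = bx / kTile;
        int tile_n = by * kTile + bx % kTile;
        std::printf("block (%d,%d,%d) -> tile (%d,%d,%d)\n",
                    bx, by, bz, tile_m, tile_n, bz);
      }
  return 0;
}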
159
include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h
Normal file
@ -0,0 +1,159 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/warp/mma_sparse_tensor_op.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC_,
|
||||
/// Operator describing the tensor operation
|
||||
typename Operator_ = arch::OpMultiplyAdd,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK = 1,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false
|
||||
>
|
||||
struct DefaultSparseMmaTensorOp;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial Specialization - inputs and output types are float - uses TF32 internally
|
||||
template <
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of target matrix multiply instruction (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor>
|
||||
struct DefaultSparseMmaTensorOp<
|
||||
WarpShape_,
|
||||
InstructionShape_,
|
||||
float, LayoutA,
|
||||
float, LayoutB,
|
||||
float, LayoutC,
|
||||
arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> {
|
||||
|
||||
// Uses TF32 internally
|
||||
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
|
||||
cutlass::arch::SparseMma<
|
||||
InstructionShape_,
|
||||
32,
|
||||
tfloat32_t, cutlass::layout::RowMajor,
|
||||
tfloat32_t, cutlass::layout::ColumnMajor,
|
||||
float, cutlass::layout::RowMajor,
|
||||
arch::OpMultiplyAdd
|
||||
>,
|
||||
cutlass::MatrixShape<1, 1> >;
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using Type = cutlass::gemm::warp::SparseMmaTensorOp<
|
||||
WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC,
|
||||
Policy, PartitionsK, AccumulatorsInRowMajor>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for m-by-n-by-kgroup
|
||||
template <
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Operator describing the tensor operation
|
||||
typename Operator_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor>
|
||||
struct DefaultSparseMmaTensorOp {
|
||||
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
|
||||
cutlass::arch::SparseMma<InstructionShape_, 32, ElementA,
|
||||
cutlass::layout::RowMajor, ElementB,
|
||||
cutlass::layout::ColumnMajor, ElementC,
|
||||
cutlass::layout::RowMajor, Operator_>,
|
||||
cutlass::MatrixShape<1, 1> >;
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using Type = cutlass::gemm::warp::SparseMmaTensorOp<
|
||||
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
Policy, PartitionsK, AccumulatorsInRowMajor>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
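A hedged usage sketch of the DefaultSparseMmaTensorOp trait added above, compiled against the headers introduced in this commit. The 64x64x64 warp tile and the 16x8x32 sparse instruction shape for half-precision operands are assumptions for illustration, not values taken from the diff:

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

// Warp-level sparse tensor op: half-precision A/B operands, float accumulators.
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;  // assumed warp tile
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;   // assumed mma.sp shape

using SparseWarpMma = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp<
    WarpShape, InstructionShape,
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    float, cutlass::layout::RowMajor                // C
>::Type;

// The A operand is expected to be 2:4 structured sparse (kSparse == 2).
static_assert(SparseWarpMma::kSparse == 2, "A operand should be 2:4 sparse");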
@ -147,6 +147,9 @@ public:
    dp4a_type
  >;

  /// Underlying matrix multiply operator (concept: arch::Mma)
  using ArchMmaOperator = typename ThreadMma::ArchMmaOperator;

  /// Shape of the underlying instruction
  using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>;

335
include/cutlass/gemm/warp/mma_sparse_tensor_op.h
Normal file
@ -0,0 +1,335 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing warp-level matrix multiply-accumulate
|
||||
operations targeting sparse Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute warp-level matrix multiply-accumulate operations targeting sparse Tensor Cores.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC_,
|
||||
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
typename Policy_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_ = 1,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class SparseMmaTensorOp {
|
||||
public:
|
||||
/// Shape of warp-level matrix operation (concept: GemmShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of multiplicand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Layout of multiplicand A
|
||||
using LayoutA = LayoutA_;
|
||||
|
||||
/// Data type of multiplicand B
|
||||
using ElementB = ElementB_;
|
||||
|
||||
/// Layout of multiplicand B
|
||||
using LayoutB = LayoutB_;
|
||||
|
||||
/// Data type of accumulator matrix C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Layout of accumulator matrix C
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
/// Number of partitions along K dimension
|
||||
static int const kPartitionsK = PartitionsK_;
|
||||
|
||||
/// Sparsity in Operand A
|
||||
static int const kSparse = Policy::Operator::kSparse;
|
||||
|
||||
/// Meta data size in bits
|
||||
static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits;
|
||||
|
||||
/// Max ID2
|
||||
static int const kMaxID2 = Policy::Operator::kMaxID2;
|
||||
|
||||
/// Data type of meta E that is moved at the same time
|
||||
using ElementE =
|
||||
typename cutlass::platform::conditional<kMaxID2 == 1, uint32_t,
|
||||
uint16_t>::type;
|
||||
|
||||
/// Number of ElementA that is associated with one ElementE
|
||||
static int const kElementsPerElementE =
|
||||
128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
/// Meta data is essentially interleaved but mapped to ColumnMajor internally
|
||||
static int const kInterleaved = 2;
|
||||
|
||||
/// Layout of meta E
|
||||
using LayoutE = cutlass::layout::ColumnMajor;
|
||||
|
||||
public:
|
||||
|
||||
/// Iterates over the A operand in memory
|
||||
using IteratorA = MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK / kSparse>, Operand::kA, ElementA,
|
||||
LayoutA,
|
||||
MatrixShape<Policy::Operator::Shape::kM,
|
||||
Policy::Operator::Shape::kK / kSparse>,
|
||||
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for A tile
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA =
|
||||
Array<typename Policy::Operator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
|
||||
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
|
||||
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for B tile
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB =
|
||||
Array<typename Policy::Operator::ElementB, FragmentB::kElements>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<
|
||||
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename Policy::Operator::Shape, typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
/// Iterates over the E operand in memory
|
||||
using IteratorE = SparseMmaTensorOpMetaTileIterator<
|
||||
MatrixShape<Shape::kM * kInterleaved,
|
      Shape::kK / kSparse / kElementsPerElementE / kInterleaved>,
      ElementE, LayoutE,
      MatrixShape<Policy::Operator::Shape::kM,
                  Policy::Operator::Shape::kK / kSparse / kElementsPerElementE /
                      kInterleaved>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

  /// Storage for E tile
  using FragmentE = typename IteratorE::Fragment;

private:

  static_assert(
      !(Shape::kM % Policy::Operator::Shape::kM) &&
      !(Shape::kN % Policy::Operator::Shape::kN),
      "Shape of warp-level Mma must be divisible by operator shape.");

  /// Number of mma operations performed
  using MmaIterations = MatrixShape<
    Shape::kM / Policy::Operator::Shape::kM,
    Shape::kN / Policy::Operator::Shape::kN
  >;

public:

  /// Underlying matrix multiply operator (concept: arch::Mma)
  typename Policy::Operator mma;

public:

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  SparseMmaTensorOp() {}

  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    TransformedFragmentA const &A,
    TransformedFragmentB const &B,
    FragmentC const &C,
    FragmentE const &E
  ) const {

    using MmaOperandA = typename Policy::Operator::FragmentA;
    using MmaOperandB = typename Policy::Operator::FragmentB;
    using MmaOperandC = typename Policy::Operator::FragmentC;
    using MmaOperandE = typename Policy::Operator::FragmentE;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    D = C;

    MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
    MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
    MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
    MmaOperandE const *ptr_E = reinterpret_cast<MmaOperandE const *>(&E);

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < MmaIterations::kRow; ++m) {

      int id2 = m % kMaxID2;

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < MmaIterations::kColumn; ++n) {

        int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

        if (AccumulatorsInRowMajor) { // matrix B is reordered
          mma(
            ptr_D[n_serpentine + m * MmaIterations::kColumn],
            ptr_A[m],
            ptr_B[n_serpentine],
            ptr_D[n_serpentine + m * MmaIterations::kColumn],
            ptr_E[(m / kMaxID2)],
            id2);
        } else {
          mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
              ptr_A[m],
              ptr_B[n_serpentine],
              ptr_D[m + n_serpentine * MmaIterations::kRow],
              ptr_E[(m / kMaxID2)],
              id2);
        }
      }
    }
#else
    assert(0);
#endif
  }

  /// Transform the mma operands to the required types
  CUTLASS_DEVICE
  void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                 FragmentA const &A, FragmentB const &B) const {

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    //
    // Define conversions from source type to instruction type
    //
    FloatRoundStyle const kRoundA =
        PreferredRoundingMode<typename Policy::Operator::ElementA,
                              ElementA>::kRound;
    FloatRoundStyle const kRoundB =
        PreferredRoundingMode<typename Policy::Operator::ElementB,
                              ElementB>::kRound;
    detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
                           FragmentA::kElements / 2, kRoundA>
        convert_A;
    NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
                          FragmentB::kElements, kRoundB>
        convert_B;
    Array<ElementA, FragmentA::kElements / 2> const *ptr_A =
        reinterpret_cast<Array<ElementA, FragmentA::kElements / 2> const *>(&A);
    Array<typename Policy::Operator::ElementA, FragmentA::kElements / 2> *
        ptr_dst_A = reinterpret_cast<Array<typename Policy::Operator::ElementA,
                                           FragmentA::kElements / 2> *>(&dst_A);

    dst_B = convert_B(B);

    ptr_dst_A[0] = convert_A(ptr_A[0]);
    ptr_dst_A[1] = convert_A(ptr_A[1]);
#else
    assert(0);
#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
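The loops in operator() above walk the warp tile in a serpentine order: odd m rows visit n in reverse, so the fragment indexed by n_serpentine repeats across the row boundary, which is likely intended to help register reuse. A small host-side sketch of that visitation order, assuming a hypothetical 64x64 warp tile and a 16x8 instruction shape (both sizes are assumptions for illustration, not values from this commit):

// Standalone illustration only; prints the (m, n) order the unrolled loops follow.
#include <cstdio>

int main() {
  int const kIterM = 64 / 16;  // MmaIterations::kRow for the assumed shapes
  int const kIterN = 64 / 8;   // MmaIterations::kColumn for the assumed shapes
  for (int m = 0; m < kIterM; ++m) {
    for (int n = 0; n < kIterN; ++n) {
      // Odd rows walk the columns backwards, mirroring n_serpentine above.
      int n_serpentine = (m % 2) ? (kIterN - 1 - n) : n;
      std::printf("mma.sp tile m=%d n=%d\n", m, n_serpentine);
    }
  }
  return 0;
}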
@ -184,14 +184,17 @@ public:
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename Policy::Operator::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
@ -210,7 +213,7 @@ public:
|
||||
/// Iterates over the A operand in memory
|
||||
using IteratorA = MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
|
||||
MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
|
||||
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for A tile
|
||||
@ -218,12 +221,12 @@ public:
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA =
|
||||
Array<typename Policy::Operator::ElementA, FragmentA::kElements>;
|
||||
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
|
||||
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
|
||||
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for B tile
|
||||
@ -231,33 +234,28 @@ public:
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB =
|
||||
Array<typename Policy::Operator::ElementB, FragmentB::kElements>;
|
||||
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<
|
||||
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename Policy::Operator::Shape, typename Policy::OpDelta>;
|
||||
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
private:
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % Policy::Operator::Shape::kM) &&
|
||||
!(Shape::kN % Policy::Operator::Shape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<
|
||||
Shape::kM / Policy::Operator::Shape::kM,
|
||||
Shape::kN / Policy::Operator::Shape::kN
|
||||
(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
||||
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN
|
||||
>;
|
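// Worked example (illustrative values, not asserted by this file): with
// Shape = 64x60 and ArchMmaOperator::Shape = 16x8, the rounded-up counts are
//   kRow    = (64 + 16 - 1) / 16 = 4
//   kColumn = (60 +  8 - 1) /  8 = 8   (a plain 60 / 8 would truncate to 7),
// so a partial tile along N still gets an mma issued for it.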
||||
|
||||
public:
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
typename Policy::Operator mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
|
||||
@ -278,9 +276,9 @@ public:
|
||||
FragmentC const &C
|
||||
) const {
|
||||
|
||||
using MmaOperandA = typename Policy::Operator::FragmentA;
|
||||
using MmaOperandB = typename Policy::Operator::FragmentB;
|
||||
using MmaOperandC = typename Policy::Operator::FragmentC;
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
D = C;
|
||||
|
||||
@ -351,22 +349,22 @@ public:
|
||||
// Define conversions from source type to instruction type
|
||||
//
|
||||
FloatRoundStyle const kRoundA =
|
||||
PreferredRoundingMode<typename Policy::Operator::ElementA,
|
||||
PreferredRoundingMode<typename ArchMmaOperator::ElementA,
|
||||
ElementA>::kRound;
|
||||
FloatRoundStyle const kRoundB =
|
||||
PreferredRoundingMode<typename Policy::Operator::ElementB,
|
||||
PreferredRoundingMode<typename ArchMmaOperator::ElementB,
|
||||
ElementB>::kRound;
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
|
||||
detail::ConvertAndPack<typename ArchMmaOperator::ElementA, ElementA,
|
||||
FragmentA::kElements, kRoundA>
|
||||
convert_A;
|
||||
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
|
||||
NumericArrayConverter<typename ArchMmaOperator::ElementB, ElementB,
|
||||
FragmentB::kElements / 2, kRoundB>
|
||||
convert_B;
|
||||
Array<ElementB, FragmentB::kElements / 2> const *ptr_B =
|
||||
reinterpret_cast<Array<ElementB, FragmentB::kElements / 2> const *>(&B);
|
||||
Array<typename Policy::Operator::ElementB, FragmentB::kElements / 2> *
|
||||
ptr_dst_B = reinterpret_cast<Array<typename Policy::Operator::ElementB,
|
||||
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / 2> *
|
||||
ptr_dst_B = reinterpret_cast<Array<typename ArchMmaOperator::ElementB,
|
||||
FragmentB::kElements / 2> *>(&dst_B);
|
||||
|
||||
dst_A = convert_A(A);
|
||||
@ -375,16 +373,16 @@ public:
|
||||
ptr_dst_B[1] = convert_B(ptr_B[1]);
|
||||
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
|
||||
detail::ConvertAndPack<typename ArchMmaOperator::ElementA, ElementA,
|
||||
FragmentA::kElements / 2, kRoundA>
|
||||
convert_A;
|
||||
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
|
||||
NumericArrayConverter<typename ArchMmaOperator::ElementB, ElementB,
|
||||
FragmentB::kElements, kRoundB>
|
||||
convert_B;
|
||||
Array<ElementA, FragmentA::kElements / 2> const *ptr_A =
|
||||
reinterpret_cast<Array<ElementA, FragmentA::kElements / 2> const *>(&A);
|
||||
Array<typename Policy::Operator::ElementA, FragmentA::kElements / 2> *
|
||||
ptr_dst_A = reinterpret_cast<Array<typename Policy::Operator::ElementA,
|
||||
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements / 2> *
|
||||
ptr_dst_A = reinterpret_cast<Array<typename ArchMmaOperator::ElementA,
|
||||
FragmentA::kElements / 2> *>(&dst_A);
|
||||
|
||||
dst_B = convert_B(B);
|
||||
|
||||
@ -291,7 +291,9 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
|
||||
!(Shape::kColumn % InstructionShape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
static_assert(
|
||||
!(AccumulatorShape::kRow % Shape::kRow) &&
|
||||
AccumulatorShape::kRow == Shape::kRow,
|
||||
"Rows of Warp Accumulator must be the same as rows of warp");
|
||||
static_assert(
|
||||
!(AccumulatorShape::kColumn % Shape::kColumn),
|
||||
"Shape of Warp Accumulator must be divisible by warp shape.");
|
||||
static_assert(
|
||||
@ -304,7 +306,16 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
|
||||
|
||||
private:
|
||||
|
||||
static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads;
|
||||
static int const kRowsPerIteration = 8;
|
||||
static int const kColumnsPerIteration = 16;
|
||||
static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kN / kThreads;
|
||||
static int const kElementsPerAccess = kRowsPerIteration * kColumnsPerIteration / kThreads;
|
||||
static int const kIterationsPerAccess = kElementsPerAccess / kElementsPerIteration;
|
||||
|
||||
// Number of iterations per actual instruction
|
||||
static int const kIterationsPerInstruction = InstructionShape::kM / kRowsPerIteration;
|
||||
|
||||
static int const kAccessStride = kIterationsPerInstruction;
|
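// Worked example (assuming InstructionShape = 16x8 and kThreads = 32; these
// values are illustrative, not required by this file):
//   kElementsPerIteration     = 8 * 8  / 32 = 2
//   kElementsPerAccess        = 8 * 16 / 32 = 4
//   kIterationsPerAccess      = 4 / 2       = 2
//   kIterationsPerInstruction = 16 / 8      = 2
// i.e. each 16x8 accumulator tile is consumed as two 8-row iterations, and
// each access spans two of those iterations.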
||||
|
||||
/// Number of mma operations performed by a warp
|
||||
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
|
||||
@ -313,13 +324,15 @@ private:
|
||||
using AccumulatorIterations = MatrixShape<AccumulatorShape::kRow / InstructionShape::kM,
|
||||
AccumulatorShape::kColumn / InstructionShape::kN>;
|
||||
|
||||
/// Number of Accesses in a warp
|
||||
using AccessIterations = MatrixShape<MmaIterations::kRow * kIterationsPerInstruction,
|
||||
MmaIterations::kColumn / kIterationsPerAccess>;
|
||||
|
||||
/// Number of K iterations
|
||||
static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn;
|
||||
static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn;
|
||||
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn
|
||||
* (AccumulatorShape::kRow / Shape::kRow);
|
||||
static int const kResidualIndex = kResidualColumn / Shape::kColumn
|
||||
* (AccumulatorShape::kRow / Shape::kRow);
|
||||
static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn;
|
||||
static int const kResidualIndex = kResidualColumn / Shape::kColumn;
|
||||
|
||||
public:
|
||||
|
||||
@ -338,8 +351,8 @@ public:
|
||||
private:
|
||||
|
||||
/// Internal access type
|
||||
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using FragmentAccessType = Array<Element, kElementsPerAccess>;
|
||||
using AccessType = Array<ElementAccumulator, kElementsPerIteration>;
|
||||
using FragmentAccessType = Array<Element, kElementsPerIteration>;
|
||||
|
||||
private:
|
||||
//
|
||||
@ -386,6 +399,11 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_index(int idx) {
|
||||
index_ = idx;
|
||||
}
|
||||
|
||||
/// Loads a fragment from the referenced part of the accumulator tile
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag, OutputOp output_op) const {
|
||||
@ -399,21 +417,35 @@ public:
|
||||
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
|
||||
// NumericArrayConverter<Element, ElementAccumulator, kElementsPerAccess, FloatRoundStyle::round_indeterminate> fragmentConverter;
|
||||
|
||||
int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow;
|
||||
int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow
|
||||
* MmaIterations::kColumn;
|
||||
int index = index_ * AccessIterations::kCount;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; m++) {
|
||||
for (int n = 0; n < MmaIterations::kColumn; n++) {
|
||||
int accumulator_access_offset =
|
||||
(m + index_m) * AccumulatorIterations::kColumn + n + index_n;
|
||||
for (int i = 0; i < AccessIterations::kCount; i++) {
|
||||
// int index_m = (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction)
|
||||
// * kIterationsPerInstruction + index % kIterationsPerInstruction;
|
||||
//
|
||||
// int index_n = (index / AccessIterations::kCount) * MmaIterations::kColumn +
|
||||
// (index % (AccessIterations::kColumn * kIterationsPerInstruction))
|
||||
// / kIterationsPerInstruction * AccessIterations::kColumn;
|
||||
//
|
||||
// int accumulator_access_offset = index_m / kIterationsPerInstruction * AccessIterations::kCount * kIterationsPerInstruction
|
||||
// + index_m % kIterationsPerInstruction + index_n * kIterationsPerInstruction;
|
||||
|
||||
frag_ptr[m * MmaIterations::kColumn + n].clear();
|
||||
int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) +
|
||||
(index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) *
|
||||
AccumulatorIterations::kColumn * kIterationsPerInstruction +
|
||||
(index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction *
|
||||
(kIterationsPerInstruction * kIterationsPerAccess) +
|
||||
(index % kIterationsPerInstruction);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kIterationsPerAccess; j++) {
|
||||
|
||||
frag_ptr[i*kIterationsPerAccess + j].clear();
|
||||
if(!(is_residual_tile_ && index_ >= kResidualIndex))
|
||||
// frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]);
|
||||
frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset], src_fragment);
|
||||
// frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]);
|
||||
frag_ptr[i*kIterationsPerAccess + j] = output_op(accumulators_[accumulator_access_offset + j * kAccessStride], src_fragment);
|
||||
}
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -106,8 +106,11 @@ public:
|
||||
/// Architecture tag
|
||||
using ArchTag = arch::Sm70;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Underlying instruction shape
|
||||
using InstructionShape = typename Policy::Operator::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
@ -133,8 +136,8 @@ public:
|
||||
ElementA,
|
||||
LayoutA,
|
||||
MatrixShape<
|
||||
Policy::Operator::Shape::kM,
|
||||
Policy::Operator::Shape::kK
|
||||
ArchMmaOperator::Shape::kM,
|
||||
ArchMmaOperator::Shape::kK
|
||||
>,
|
||||
Policy::OpDelta::kRow,
|
||||
kThreadCount
|
||||
@ -150,8 +153,8 @@ public:
|
||||
ElementB,
|
||||
LayoutB,
|
||||
MatrixShape<
|
||||
Policy::Operator::Shape::kK,
|
||||
Policy::Operator::Shape::kN
|
||||
ArchMmaOperator::Shape::kK,
|
||||
ArchMmaOperator::Shape::kN
|
||||
>,
|
||||
Policy::OpDelta::kRow,
|
||||
kThreadCount
|
||||
@ -165,7 +168,7 @@ public:
|
||||
MatrixShape<Shape::kM, Shape::kN>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
typename Policy::Operator::Shape,
|
||||
typename ArchMmaOperator::Shape,
|
||||
typename Policy::OpDelta
|
||||
>;
|
||||
|
||||
@ -175,14 +178,14 @@ public:
|
||||
private:
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % Policy::Operator::Shape::kM) &&
|
||||
!(Shape::kN % Policy::Operator::Shape::kN),
|
||||
!(Shape::kM % ArchMmaOperator::Shape::kM) &&
|
||||
!(Shape::kN % ArchMmaOperator::Shape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<
|
||||
InterleavedTileShape::kM / Policy::Operator::Shape::kM,
|
||||
InterleavedTileShape::kN / Policy::Operator::Shape::kN
|
||||
InterleavedTileShape::kM / ArchMmaOperator::Shape::kM,
|
||||
InterleavedTileShape::kN / ArchMmaOperator::Shape::kN
|
||||
>;
|
||||
using TileIterations = MatrixShape<
|
||||
Shape::kM / InterleavedTileShape::kM,
|
||||
@ -195,7 +198,7 @@ private:
|
||||
public:
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
typename Policy::Operator mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
|
||||
@ -215,9 +218,9 @@ public:
|
||||
FragmentB const &B,
|
||||
FragmentC const &C) {
|
||||
|
||||
using MmaOperandA = typename Policy::Operator::FragmentA;
|
||||
using MmaOperandB = typename Policy::Operator::FragmentB;
|
||||
using MmaOperandC = typename Policy::Operator::FragmentC;
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
D = C;
|
||||
|
||||
|
||||
@ -241,6 +241,9 @@ public:
|
||||
int access_strided_idx = -1;
|
||||
|
||||
if (Policy::LdsmShape::kContiguous == 4) {
|
||||
// Matrix multiply 1688 A/B
|
||||
// Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block).
|
||||
// Four blocks are next to each other in the contiguous dimension.
|
||||
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i);
|
||||
access_contiguous_idx = (quad_pair ^ lane_in_quad);
|
||||
access_strided_idx = lane_in_quad_pair;
|
||||
@ -262,7 +265,17 @@ public:
|
||||
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1));
|
||||
access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad);
|
||||
access_strided_idx = lane_in_quad_quad;
|
||||
} else if (Policy::LdsmShape::kContiguous == 1) {
|
||||
// Matrix multiply 16832.SP B
|
||||
// Q0
|
||||
// Q1
|
||||
// Q2
|
||||
// Q3
|
||||
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2));
|
||||
access_contiguous_idx = ((i & 3) ^ lane_in_quad);
|
||||
access_strided_idx = lane_id;
|
||||
}
|
||||
|
||||
int access_contiguous =
|
||||
partition_contiguous_idx * Layout::PartitionShape::kContiguous +
|
||||
access_contiguous_idx;
|
||||
@ -531,24 +544,24 @@ class MmaTensorOpMultiplicandTileIterator<
|
||||
!(Shape::kContiguous % InstructionShape::kContiguous),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
// Determine number of elements along outer dimension per individual LDS.32
|
||||
// op. Every one warp of LDS.32 loads 8x4 elements
|
||||
// Determine number of elements along outer dimension per individual 32bit
|
||||
// shared memory load op. Every one warp of 32bit shared memory load loads
|
||||
// 8x4 elements
|
||||
static int const kLdsOpInner = Layout::TileShape::kStrided;
|
||||
static int const kLdsOpOuter = kThreads / kLdsOpInner;
|
||||
|
||||
static_assert(!(Shape::kContiguous % kLdsOpOuter),
|
||||
"Shape of warp-level mma must be divisible by LDS.32's "
|
||||
"Shape of warp-level mma must be divisible by 32bit "
|
||||
"fundamental tile size.");
|
||||
|
||||
static_assert(!(Shape::kStrided % kLdsOpInner),
|
||||
"Shape of warp-level mma must be divisible by LDS.32's "
|
||||
"Shape of warp-level mma must be divisible by 32bit "
|
||||
"fundamental tile size.");
|
||||
|
||||
/// Number of LDS.32 instructions needed by one MMA instruction
|
||||
/// 1684 A 2x1
|
||||
/// 1684 B 1x1
|
||||
/// 1688 A 2x2
|
||||
/// 1688 B 1x2
|
||||
/// Number of 32 bit shared memory load instructions needed by one MMA instruction
|
||||
/// 1688 A 2x2
|
||||
/// 1688 B 1x2
|
||||
/// 16816 B 1x4
|
||||
static int const LdsShapeContiguous =
|
||||
InstructionShape::kContiguous / kLdsOpOuter;
|
||||
static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner;
|
||||
@ -639,6 +652,8 @@ class MmaTensorOpMultiplicandTileIterator<
|
||||
if (Shape::kContiguous ==
|
||||
Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) {
|
||||
if (tile_offset.contiguous() % 2) {
|
||||
// Matrix multiply 1688 pointer_[0] <=> pointer_[4] pointer_[1] <=> pointer_[5]
|
||||
// pointer_[2] <=> pointer_[6] pointer_[3] <=> pointer_[7]
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kPointerCount / 2; ++i) {
|
||||
AccessType const *tmp_pointer = pointer_[i];
|
||||
@ -1535,6 +1550,14 @@ class MmaTensorOpMultiplicandTileIterator<
|
||||
access_strided_idx =
|
||||
(lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor;
|
||||
}
|
||||
else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) {
|
||||
// Matrix multiply 16832.SP B
|
||||
// Q0 Q1 Q2 Q3
|
||||
partition_contiguous_idx = (lane_id % Layout::kFactor);
|
||||
access_contiguous_idx =
|
||||
(quad_pair ^ (lane_in_quad_pair / Layout::kFactor));
|
||||
access_strided_idx = lane_in_quad_pair / Layout::kFactor;
|
||||
}
|
||||
} else if (Layout::kFactor == 1) {
|
||||
// Super Matrix multiply kBlock = 64
|
||||
if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
|
||||
@ -1565,6 +1588,13 @@ class MmaTensorOpMultiplicandTileIterator<
|
||||
access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad);
|
||||
access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3);
|
||||
}
|
||||
else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) {
|
||||
// Matrix multiply 16832.SP B
|
||||
// Q0 Q1 Q2 Q3
|
||||
partition_contiguous_idx = (lane_in_quad_pair >> 2);
|
||||
access_contiguous_idx = (quad_pair ^ lane_in_quad);
|
||||
access_strided_idx = lane_in_quad_pair;
|
||||
}
|
||||
}
|
||||
|
||||
int access_contiguous =
|
||||
@ -2369,17 +2399,18 @@ class MmaTensorOpAccumulatorTileIterator<
|
||||
|
||||
/// Internal structure of iterator - made public to enable introspection
|
||||
struct Policy {
|
||||
static_assert(
|
||||
static bool const kDivisible =
|
||||
!(Shape::kRow % InstructionShape::kM) &&
|
||||
!(Shape::kColumn % InstructionShape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
!(Shape::kColumn % InstructionShape::kN);
|
||||
|
||||
static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
|
||||
"Layouts must be defined for logical MatrixCoord coordinate space.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
|
||||
Shape::kColumn / InstructionShape::kN>;
|
||||
using MmaIterations = MatrixShape<
|
||||
(Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM,
|
||||
(Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN
|
||||
>;
|
||||
};
|
||||
|
||||
private:
|
||||
@ -2398,7 +2429,9 @@ public:
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment = Array<Element, Shape::kCount / kThreads>;
|
||||
using Fragment = Array<
|
||||
Element,
|
||||
Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>;
|
||||
|
||||
private:
|
||||
|
||||
@ -2667,17 +2700,18 @@ class MmaTensorOpAccumulatorTileIterator<Shape_, Element_,
|
||||
|
||||
/// Internal structure of iterator - made public to enable introspection
|
||||
struct Policy {
|
||||
static_assert(
|
||||
static bool const kDivisible =
|
||||
!(Shape::kRow % InstructionShape::kM) &&
|
||||
!(Shape::kColumn % InstructionShape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
!(Shape::kColumn % InstructionShape::kN);
|
||||
|
||||
static_assert(platform::is_same<TensorCoord, MatrixCoord>::value,
|
||||
"Layouts must be defined for logical MatrixCoord coordinate space.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<Shape::kRow / InstructionShape::kM,
|
||||
Shape::kColumn / InstructionShape::kN>;
|
||||
using MmaIterations = MatrixShape<
|
||||
(Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM,
|
||||
(Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN
|
||||
>;
|
||||
};
|
||||
|
||||
private:
|
||||
@ -2696,7 +2730,8 @@ public:
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment = Array<Element, Shape::kCount / kThreads>;
|
||||
using Fragment = Array<Element,
|
||||
Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>;
|
||||
|
||||
private:
|
||||
|
||||
|
||||
@ -1570,6 +1570,839 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Tile iterator specialized for canonical matrix layouts
|
||||
template <
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Operand identity
|
||||
Operand Operand_,
|
||||
/// Data type of A elements
|
||||
typename Element_,
|
||||
/// Layout of operand
|
||||
typename Layout_,
|
||||
/// Shape of one matrix production operation (concept: MatrixShape)
|
||||
typename InstructionShape_,
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
||||
/// MatrixShape)
|
||||
int OpDelta_,
|
||||
/// Number of threads participating in one matrix operation
|
||||
int Threads = 32,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_ = 1>
|
||||
class MmaTensorOpMultiplicandTileIteratorCanonical {
|
||||
public:
|
||||
|
||||
/// Shape of tile to load (concept: MatrixShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Operand tag
|
||||
static Operand const kOperand = Operand_;
|
||||
|
||||
/// Basic check
|
||||
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
|
||||
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
/// Layout of source tile
|
||||
using Layout = Layout_;
|
||||
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
using InstructionShape = InstructionShape_;
|
||||
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
|
||||
static int const kOpDelta = OpDelta_;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = 32;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
|
||||
/// Index type
|
||||
using Index = typename TensorRef::Index;
|
||||
|
||||
/// Long Index type
|
||||
using LongIndex = typename TensorRef::LongIndex;
|
||||
|
||||
/// Coordinate for an element in the tensor
|
||||
using TensorCoord = typename TensorRef::TensorCoord;
|
||||
|
||||
/// Number of elements accessed per Shared Memory load
|
||||
static int const kElementsPerAccess =
|
||||
(sizeof_bits<Element>::value >= 32 ? 1 : 32 / sizeof_bits<Element>::value);
|
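  // For example: 16-bit elements (half_t) give 32 / 16 = 2 elements per access,
  // 8-bit elements give 4, and 32-bit or wider elements fall back to 1.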
||||
|
||||
private:
|
||||
|
||||
static int const kWarpShapeOuter =
|
||||
(kOperand == Operand::kA ? Shape::kRow : Shape::kColumn);
|
||||
|
||||
static int const kWarpShapeInner =
|
||||
(kOperand == Operand::kA ? Shape::kColumn : Shape::kRow);
|
||||
|
||||
|
||||
/// Rounded up instruction counts
|
||||
using InstructionCount = MatrixShape<
|
||||
Shape::kRow / InstructionShape::kRow,
|
||||
Shape::kColumn / InstructionShape::kColumn
|
||||
>;
|
||||
|
||||
/// Rounded up tile dimensions
|
||||
using WarpShapeDivisible = MatrixShape<
|
||||
InstructionCount::kRow * InstructionShape::kRow,
|
||||
InstructionCount::kColumn * InstructionShape::kColumn
|
||||
>;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment = Array<
|
||||
Element,
|
||||
WarpShapeDivisible::kRow * WarpShapeDivisible::kColumn / kThreads
|
||||
>;
|
||||
|
||||
/// Memory access type
|
||||
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
||||
|
||||
private:
|
||||
|
||||
/// Underlying tensor reference
|
||||
TensorRef ref_;
|
||||
|
||||
/// Extent of tensor
|
||||
MatrixCoord extent_;
|
||||
|
||||
/// Origin
|
||||
MatrixCoord origin_;
|
||||
|
||||
/// Used to conditionally enable extents checking
|
||||
bool divisible_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor constructs null iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical(): divisible_(true) { }
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical(
|
||||
TensorRef const &ref,
|
||||
int lane_id
|
||||
): ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) {
|
||||
|
||||
if (kOperand == Operand::kA) {
|
||||
origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
|
||||
}
|
||||
else {
|
||||
origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
|
||||
}
|
||||
|
||||
ref_.add_coord_offset(origin_);
|
||||
}
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical(
|
||||
TensorRef const &ref,
|
||||
TensorCoord extent,
|
||||
int lane_id
|
||||
): ref_(ref), extent_(extent), divisible_(false) {
|
||||
|
||||
if (kOperand == Operand::kA) {
|
||||
origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess);
|
||||
}
|
||||
else {
|
||||
origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4);
|
||||
}
|
||||
|
||||
ref_.add_coord_offset(origin_);
|
||||
}
|
||||
|
||||
/// Adds a pointer offset to internal pointer(s) to advance through memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical &add_pointer_offset(LongIndex offset) {
|
||||
|
||||
ref_.add_pointer_offset(offset);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical &add_tile_offset(TensorCoord const &tile_offset) {
|
||||
|
||||
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
|
||||
origin_ += coord_offset;
|
||||
|
||||
ref_.add_coord_offset(coord_offset);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical & operator++() {
|
||||
|
||||
if (kOperand == Operand::kA) {
|
||||
add_tile_offset({0, 1});
|
||||
}
|
||||
else {
|
||||
add_tile_offset({1, 0});
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical & operator--() {
|
||||
|
||||
if (kOperand == Operand::kA) {
|
||||
add_tile_offset({0, -1});
|
||||
}
|
||||
else {
|
||||
add_tile_offset({-1, 0});
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical & operator+=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(tile_offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIteratorCanonical & operator-=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(-tile_offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory at the location pointed to by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag) const {
|
||||
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index pointer_offset) const {
|
||||
|
||||
int const kWarpShapeDivisibleInner =
|
||||
(kOperand == Operand::kA ? WarpShapeDivisible::kColumn : WarpShapeDivisible::kRow);
|
||||
|
||||
// Take advantage of Tensor Op's 8 x 4T access pattern
|
||||
int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
|
||||
|
||||
AccessType *access_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
if (kOperand == Operand::kA) {
|
||||
int const kTilesPerInstruction = InstructionShape::kRow / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) {
|
||||
int access_idx =
|
||||
access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx);
|
||||
|
||||
MatrixCoord offset(
|
||||
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
|
||||
inner_idx * 4 * kElementsPerAccess);
|
||||
|
||||
MatrixCoord access_coord = origin_ + offset;
|
||||
|
||||
if (divisible_ ||
|
||||
(access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) {
|
||||
|
||||
access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
|
||||
ref_.data() + ref_.offset(offset));
|
||||
}
|
||||
else {
|
||||
AccessType zero;
|
||||
zero.clear();
|
||||
access_ptr[access_idx] = zero;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
|
||||
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
|
||||
|
||||
MatrixCoord offset(
|
||||
inner_idx * 4 * kElementsPerAccess,
|
||||
inst_n_idx * 8);
|
||||
|
||||
MatrixCoord access_coord = origin_ + offset;
|
||||
|
||||
if (divisible_ ||
|
||||
(access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) {
|
||||
|
||||
access_ptr[access_idx] = *reinterpret_cast<AccessType const *>(
|
||||
ref_.data() + ref_.offset(offset));
|
||||
}
|
||||
else {
|
||||
AccessType zero;
|
||||
zero.clear();
|
||||
access_ptr[access_idx] = zero;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index byte_offset) const {
|
||||
|
||||
load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits<Element>::value);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset) const {
|
||||
|
||||
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
|
||||
|
||||
load_with_pointer_offset(frag, ref_.offset(coord_offset));
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index pointer_offset) const {
|
||||
|
||||
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
|
||||
|
||||
load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index byte_offset) const {
|
||||
|
||||
TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
|
||||
|
||||
load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits<Element>::value);
|
||||
}
|
||||
|
||||
/// Notify the iterator which k-group it is currently pointing to.
|
||||
///
|
||||
/// This does not advance the iterator. Rather, it overrides its internal
|
||||
/// tracking with constant-valued k-group index to enable the compiler to
|
||||
/// fold constants and achieve more efficient code.
|
||||
///
|
||||
/// This is used by some nontrivial permuted layouts.
|
||||
CUTLASS_DEVICE
|
||||
void set_kgroup_index(int k_group) {
|
||||
// no operation
|
||||
}
|
||||
};
|
||||
|
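The canonical iterator above derives each lane's starting coordinate from its lane id and then replays the Tensor Op 8-row by 4-thread access pattern across the instruction tiles. A standalone sketch of the coordinates one lane touches for the A operand, assuming 16-bit elements, a 32x16 warp tile, and a 16x16 instruction A tile (all assumed values chosen for illustration, not taken from this file):

// Mirrors the loop structure of load_with_pointer_offset() for Operand::kA.
#include <cstdio>

int main() {
  int const kElementsPerAccess = 2;          // 32 bits / 16-bit element
  int const kInstRows = 16, kInstCols = 16;  // assumed instruction A tile
  int const kWarpRows = 32, kWarpCols = 16;  // assumed warp A tile
  int const kTilesPerInstruction = kInstRows / 8;
  int const kInstCountRow = kWarpRows / kInstRows;
  int const kAccessesInner = (kWarpCols / kElementsPerAccess) / 4;

  int lane_id = 5;  // any lane in [0, 32)
  int origin_row = lane_id / 4;
  int origin_col = (lane_id % 4) * kElementsPerAccess;

  for (int inst_m = 0; inst_m < kInstCountRow; ++inst_m)
    for (int inner = 0; inner < kAccessesInner; ++inner)
      for (int tile_m = 0; tile_m < kTilesPerInstruction; ++tile_m)
        std::printf("lane %d reads %d elements at (%d, %d)\n",
                    lane_id, kElementsPerAccess,
                    origin_row + tile_m * 8 + inst_m * kInstRows,
                    origin_col + inner * 4 * kElementsPerAccess);
  return 0;
}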
||||
/// Wrapper for ColumnMajor
|
||||
template <
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Identifies A or B multiplicand
|
||||
Operand Operand_,
|
||||
/// Data type of elements
|
||||
typename Element_,
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
typename InstructionShape_,
|
||||
/// Interval between adjacent *MMA instructions (in units of MMA
|
||||
/// instructions)
|
||||
int OpDelta_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_>
|
||||
class MmaTensorOpMultiplicandTileIterator<
|
||||
Shape_, Operand_, Element_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
InstructionShape_, OpDelta_, 32, PartitionsK_> {
|
||||
public:
|
||||
|
||||
/// Shape of tile to load (concept: PitchLinearShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Operand tag
|
||||
static Operand const kOperand = Operand_;
|
||||
|
||||
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
|
||||
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
/// Layout of source tile
|
||||
using Layout = cutlass::layout::ColumnMajor;
|
||||
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
using InstructionShape = InstructionShape_;
|
||||
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
|
||||
static int const kOpDelta = OpDelta_;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = 32;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
|
||||
/// Index type
|
||||
using Index = typename TensorRef::Index;
|
||||
|
||||
/// Long Index type
|
||||
using LongIndex = typename TensorRef::LongIndex;
|
||||
|
||||
/// Coordinate for an element in the tensor
|
||||
using TensorCoord = typename TensorRef::TensorCoord;
|
||||
|
||||
/// Underlying tile iterator implementation
|
||||
using Base = MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
Shape, kOperand, Element,
|
||||
layout::ColumnMajor,
|
||||
InstructionShape,
|
||||
kOpDelta, kThreads, PartitionsK_>;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment = typename Base::Fragment;
|
||||
|
||||
private:
|
||||
|
||||
/// Underlying tile iterator
|
||||
Base iterator_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor constructs null iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator() { }
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator(
|
||||
TensorRef const &ref,
|
||||
int lane_id
|
||||
): iterator_({ref.data(), ref.stride()}, lane_id) {
|
||||
}
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator(
|
||||
TensorRef const &ref,
|
||||
TensorCoord const & extent,
|
||||
int lane_id
|
||||
): iterator_({ref.data(), ref.stride()}, extent, lane_id) {
|
||||
}
|
||||
|
||||
/// Adds a pointer offset to internal pointer(s) to advance through memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
|
||||
|
||||
iterator_.add_pointer_offset(offset);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
|
||||
|
||||
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator++() {
|
||||
|
||||
++iterator_;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator--() {
|
||||
|
||||
--iterator_;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory at the location pointed to by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag) const {
|
||||
|
||||
iterator_.load(frag);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index pointer_offset) const {
|
||||
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index byte_offset) const {
|
||||
iterator_.load_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset) const {
|
||||
// TODO
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index pointer_offset) const {
|
||||
// TODO
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index byte_offset) const {
|
||||
iterator_.load_with_byte_offset(
|
||||
frag,
|
||||
{tile_offset.contiguous(), tile_offset.strided()},
|
||||
byte_offset);
|
||||
}
|
||||
|
||||
/// Notify the iterator which k-group it is currently pointing to.
|
||||
///
|
||||
/// This does not advance the iterator. Rather, it overrides its internal
|
||||
/// tracking with constant-valued k-group index to enable the compiler to
|
||||
/// fold constants and achieve more efficient code.
|
||||
///
|
||||
/// This is used by some nontrivial permuted layouts.
|
||||
CUTLASS_DEVICE
|
||||
void set_kgroup_index(int k_group) {
|
||||
iterator_.set_kgroup_index(k_group);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Wrapper for RowMajor
|
||||
template <
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Identifies A or B multiplicand
|
||||
Operand Operand_,
|
||||
/// Data type of elements
|
||||
typename Element_,
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
typename InstructionShape_,
|
||||
/// Interval between adjacent *MMA instructions (in units of MMA
|
||||
/// instructions)
|
||||
int OpDelta_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_>
|
||||
class MmaTensorOpMultiplicandTileIterator<
|
||||
Shape_, Operand_, Element_,
|
||||
cutlass::layout::RowMajor,
|
||||
InstructionShape_, OpDelta_, 32, PartitionsK_> {
|
||||
public:
|
||||
|
||||
/// Shape of tile to load (concept: PitchLinearShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Operand tag
|
||||
static Operand const kOperand = Operand_;
|
||||
|
||||
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
|
||||
"MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
/// Layout of source tile
|
||||
using Layout = cutlass::layout::RowMajor;
|
||||
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
using InstructionShape = InstructionShape_;
|
||||
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
|
||||
static int const kOpDelta = OpDelta_;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = 32;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
|
||||
/// Index type
|
||||
using Index = typename TensorRef::Index;
|
||||
|
||||
/// Long Index type
|
||||
using LongIndex = typename TensorRef::LongIndex;
|
||||
|
||||
/// Coordinate for an element in the tensor
|
||||
using TensorCoord = typename TensorRef::TensorCoord;
|
||||
|
||||
/// Underlying tile iterator implementation
|
||||
using Base = MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
Shape, kOperand, Element,
|
||||
layout::RowMajor,
|
||||
InstructionShape,
|
||||
kOpDelta, kThreads, PartitionsK_>;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment = typename Base::Fragment;
|
||||
|
||||
private:
|
||||
|
||||
/// Underlying tile iterator
|
||||
Base iterator_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor constructs null iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator() { }
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator(
|
||||
TensorRef const &ref,
|
||||
int lane_id
|
||||
): iterator_({ref.data(), ref.stride()}, lane_id) {
|
||||
}
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator(
|
||||
TensorRef const &ref,
|
||||
TensorCoord const &extent,
|
||||
int lane_id
|
||||
): iterator_({ref.data(), ref.stride()}, extent, lane_id) {
|
||||
}
|
||||
|
||||
/// Adds a pointer offset to internal pointer(s) to advance through memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
|
||||
|
||||
iterator_.add_pointer_offset(offset);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
|
||||
|
||||
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator++() {
|
||||
|
||||
++iterator_;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator--() {
|
||||
|
||||
--iterator_;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of the tensor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column()));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory at the location pointed to by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag) const {
|
||||
|
||||
iterator_.load(frag);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index pointer_offset) const {
|
||||
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index byte_offset) const {
|
||||
iterator_.load_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset) const {
|
||||
// TODO
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index pointer_offset) const {
|
||||
// TODO
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index byte_offset) const {
|
||||
iterator_.load_with_byte_offset(
|
||||
frag,
|
||||
{tile_offset.contiguous(), tile_offset.strided()},
|
||||
byte_offset);
|
||||
}
|
||||
|
||||
/// Notify the iterator which k-group it is currently pointing to.
|
||||
///
|
||||
/// This does not advance the iterator. Rather, it overrides its internal
|
||||
/// tracking with constant-valued k-group index to enable the compiler to
|
||||
/// fold constants and achieve more efficient code.
|
||||
///
|
||||
/// This is used by some nontrivial permuted layouts.
|
||||
CUTLASS_DEVICE
|
||||
void set_kgroup_index(int k_group) {
|
||||
iterator_.set_kgroup_index(k_group);
|
||||
}
|
||||
};
|
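The extent-taking constructors above are what allow warp tiles that are not fully covered by the tensor: when divisible_ is false, every access is compared against the extent and out-of-range accesses yield zero-filled fragments instead of reading memory. A plain-C++ sketch of that predication, with made-up names to keep it self-contained:

#include <cassert>

// Returns the value at (row, col), or 0 when the coordinate lies outside the extent.
float load_or_zero(float const *data, int ld, int rows, int cols, int row, int col) {
  return (row < rows && col < cols) ? data[row * ld + col] : 0.0f;
}

int main() {
  float tensor[2 * 3] = {1, 2, 3, 4, 5, 6};
  assert(load_or_zero(tensor, 3, 2, 3, 1, 2) == 6.0f);  // inside the extent
  assert(load_or_zero(tensor, 3, 2, 3, 2, 0) == 0.0f);  // predicated off
  return 0;
}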
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
|
||||
374
include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h
Normal file
@ -0,0 +1,374 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators to load sparse meta data used by warp-level matrix multiply operations
           targeting Sparse Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"

#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

template <
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename Element_,
|
||||
/// Layout of operand
|
||||
typename Layout_,
|
||||
/// Shape of one matrix production operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
||||
/// MatrixShape)
|
||||
int OpDelta_,
|
||||
/// Number of threads participating in one matrix operation
|
||||
int Threads,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_ = 1>
|
||||
class SparseMmaTensorOpMetaTileIterator {
|
||||
public:
|
||||
/// Shape of tile to load (concept: PitchLinearShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
/// Layout of source tile
|
||||
using Layout = Layout_;
|
||||
|
||||
/// Shape of one matrix product operation (concept: GemmShape)
|
||||
using InstructionShape = InstructionShape_;
|
||||
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
||||
/// MatrixShape)
|
||||
static int const kOpDelta = OpDelta_;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = 32;
|
||||
|
||||
/// Number of partitions along K dimension
|
||||
static int const kPartitionsK = PartitionsK_;
|
||||
|
||||
static int const kSparse = 2;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
|
||||
/// Index type
|
||||
using Index = typename TensorRef::Index;
|
||||
|
||||
/// Long Index type
|
||||
using LongIndex = typename TensorRef::LongIndex;
|
||||
|
||||
/// Coordinate for an element in the tensor
|
||||
using TensorCoord = typename TensorRef::TensorCoord;
|
||||
|
||||
/// Internal structure of iterator - made public to enable introspection
|
||||
struct Policy {
|
||||
static_assert(
|
||||
!(Shape::kColumn % InstructionShape::kColumn),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
|
||||
|
||||
// Determine number of elements along outer dimension per individual LDSM op
|
||||
static int const kLdsmOpOuter = InstructionShape::kColumn;
|
||||
static int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter;
|
||||
|
||||
static_assert(!(Shape::kColumn % kLdsmOpOuter),
|
||||
"Shape of warp-level mma must be divisible by LDSM's "
|
||||
"fundamental tile size.");
|
||||
|
||||
static_assert(!(Shape::kRow % kLdsmOpInner),
|
||||
"Shape of warp-level mma must be divisible by LDSM's "
|
||||
"fundamental tile size.");
|
||||
|
||||
/// Shape of one individual LDSM instruction
|
||||
static int const LdsmShapeColumn =
|
||||
InstructionShape::kColumn / kLdsmOpOuter;
|
||||
static int const LdsmShapeRow =
|
||||
((4 / LdsmShapeColumn * kLdsmOpInner) > Shape::kRow)
|
||||
? (Shape::kRow / kLdsmOpInner)
|
||||
: (4 / LdsmShapeColumn);
|
||||
using LdsmShape =
|
||||
layout::PitchLinearShape<LdsmShapeRow, LdsmShapeColumn>;
|
||||
|
||||
/// Number and arrangement of LDSM instructions
|
||||
using LdsmIterations = layout::PitchLinearShape<
|
||||
Shape::kRow / kLdsmOpInner / LdsmShapeRow,
|
||||
1>;
|
||||
|
||||
/// Number of groups for each tile
|
||||
static int const kGroupsPerTile =
|
||||
Shape::kColumn / InstructionShape::kColumn;
|
||||
};
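// Worked example (illustrative assumption only, not tied to a particular
// CUTLASS configuration): for a 16-bit metadata element, Shape = <64, 32>,
// and InstructionShape::kColumn = 16, the derived quantities are
//   kElementsPerAccess = 128 / 16 = 8
//   kLdsmOpOuter = 16,  kLdsmOpInner = 8 * 8 / 16 = 4
//   LdsmShapeColumn = 16 / 16 = 1
//   LdsmShapeRow = (4 / 1) * 4 = 16 > 64 ? 64 / 4 : 4  ->  4
//   LdsmShape = <4, 1>,  LdsmIterations = <64 / 4 / 4, 1> = <4, 1>
//   kGroupsPerTile = 32 / 16 = 2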
|
||||
|
||||
private:
|
||||
/// Not working on this feature at the moment.
|
||||
static_assert(kOpDelta == 1,
|
||||
"Alternative arrangements not supported at present.");
|
||||
|
||||
/// Pointer type used for accesses
|
||||
using AccessType = Array<Element, Policy::kElementsPerAccess>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment =
|
||||
Array<Element, Shape::kRow * InstructionShape::kColumn / kThreads>;
|
||||
|
||||
private:
|
||||
|
||||
/// Layout object storing stride values
|
||||
Index stride_;
|
||||
|
||||
/// Shared memory base pointers - not advanced
|
||||
AccessType const *pointer_;
|
||||
|
||||
/// Byte offset incremented as iterator advances
|
||||
Index byte_offset_;
|
||||
|
||||
/// Internal counter used to determine when to increment byte offset and when
|
||||
/// to XOR it
|
||||
int k_group_idx_;
|
||||
|
||||
public:
|
||||
/// Default ctor constructs null iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator()
|
||||
: pointer_(nullptr),
|
||||
stride_(0),
|
||||
byte_offset_(0),
|
||||
k_group_idx_(0) {}
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator(TensorRef const &ref, int lane_id)
|
||||
: pointer_(reinterpret_cast<AccessType const *>(ref.data())),
|
||||
stride_(ref.stride(0) / Policy::kElementsPerAccess),
|
||||
byte_offset_(0),
|
||||
k_group_idx_(0) {
|
||||
|
||||
int access_contiguous = (lane_id % (Shape::kRow / Policy::kElementsPerAccess));
|
||||
int access_strided = (lane_id / (Shape::kRow / Policy::kElementsPerAccess));
|
||||
|
||||
byte_offset_ = (access_contiguous + access_strided * stride_) *
|
||||
sizeof_bits<Element>::value * Policy::kElementsPerAccess / 8;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset to internal pointer(s) to advance through memory
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator &add_pointer_offset(LongIndex offset) {
|
||||
byte_offset_ += offset * sizeof_bits<Element>::value / 8;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole
|
||||
/// tiles
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator &add_tile_offset(
|
||||
TensorCoord const &tile_offset) {
|
||||
int offset = tile_offset.row() * Shape::kRow +
|
||||
tile_offset.column() * InstructionShape::kColumn * stride_ *
|
||||
Policy::kElementsPerAccess;
|
||||
|
||||
add_pointer_offset(offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator &operator++() {
|
||||
add_tile_offset({0, 1});
|
||||
|
||||
if (kPartitionsK > 1) {
|
||||
++k_group_idx_;
|
||||
// Jump to next stage
|
||||
if (k_group_idx_ == Policy::kGroupsPerTile) {
|
||||
k_group_idx_ = 0;
|
||||
add_tile_offset(
|
||||
{0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)});
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator &operator--(){
|
||||
byte_offset_ -= stride_ * InstructionShape::kColumn *
|
||||
sizeof_bits<Element>::value * Policy::kElementsPerAccess /
|
||||
8;

return *this;
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of
|
||||
///< the tensor
|
||||
CUTLASS_DEVICE SparseMmaTensorOpMetaTileIterator &
|
||||
operator+=(TensorCoord const &tile_offset) {
|
||||
add_tile_offset(tile_offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< advances in units of whole tiles along the logical coordinate space of
|
||||
///< the tensor
|
||||
CUTLASS_DEVICE
|
||||
SparseMmaTensorOpMetaTileIterator &operator-=(
|
||||
TensorCoord const &tile_offset) {
|
||||
add_tile_offset(-tile_offset);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory at the location pointed to by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag) const { load_with_byte_offset(frag, 0); }
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset in units of bytes
|
||||
Index byte_offset) const {
|
||||
Array<unsigned, Policy::LdsmShape::kCount> *fetch_ptr =
|
||||
reinterpret_cast<Array<unsigned, Policy::LdsmShape::kCount> *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) {
|
||||
|
||||
int access_idx = c + s * Policy::LdsmIterations::kContiguous;
|
||||
|
||||
AccessType const *source_ptr =
|
||||
pointer_ +
|
||||
Policy::LdsmShape::kContiguous * Policy::kLdsmOpInner * c +
|
||||
Policy::LdsmShape::kStrided * s * stride_;
|
||||
|
||||
char const *source_byte_ptr = reinterpret_cast<char const *>(source_ptr) +
|
||||
byte_offset + byte_offset_;
|
||||
|
||||
cutlass::arch::ldsm<layout::RowMajor, Policy::LdsmShape::kCount>(
|
||||
fetch_ptr[access_idx], source_byte_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with additional logical offset
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a linear offset
|
||||
Index pointer_offset) const {
|
||||
load_with_byte_offset(frag, pointer_offset * sizeof(Element));
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset) const {
|
||||
load_with_byte_offset(frag, tile_offset, 0);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index pointer_offset) const {
|
||||
load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element));
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory with logical offset in units of whole tiles.
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(
|
||||
/// fragment to load from the tensor
|
||||
Fragment &frag,
|
||||
/// loads a tile with a logical offset in units of whole tiles
|
||||
TensorCoord const &tile_offset,
|
||||
/// loads a tile with a logical offset AND a pointer offset
|
||||
Index byte_offset) const {
|
||||
Index pointer_offset =
|
||||
tile_offset.contiguous() * Shape::kRow / Layout::kElementsPerAccess +
|
||||
tile_offset.strided() * InstructionShape::kColumn * stride_;
|
||||
|
||||
byte_offset += sizeof(AccessType) * pointer_offset;
|
||||
|
||||
load_with_byte_offset(frag, byte_offset);
|
||||
}
|
||||
|
||||
/// Notify the iterator which k-group it is currently pointing to.
|
||||
///
|
||||
/// This does not advance the iterator. Rather, it overrides its internal
|
||||
/// tracking with constant-valued k-group index to enable the compiler to
|
||||
/// fold constants and achieve more efficient code.
|
||||
///
|
||||
/// This is used by some nontrivial permuted layouts.
|
||||
CUTLASS_DEVICE
|
||||
void set_kgroup_index(int k_group) {
|
||||
// no op
|
||||
}
|
||||
};
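// Usage sketch (illustrative only, not part of the original change): a
// warp-level sparse mainloop would typically construct this iterator over the
// metadata tile held in shared memory and load one fragment per k-group.
// The names below are hypothetical.
//
//   using MetaIterator = SparseMmaTensorOpMetaTileIterator<Shape, Element,
//                                                          Layout, InstructionShape,
//                                                          1 /*OpDelta*/, 32 /*Threads*/>;
//   MetaIterator iter(meta_smem_ref, lane_id);
//   typename MetaIterator::Fragment frag;
//   iter.load(frag);   // issues LDSM accesses for this thread's portion
//   ++iter;            // advance to the metadata of the next k-group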
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -349,7 +349,7 @@ struct alignas(2) half_t {
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
half_t() { }
|
||||
half_t() : storage(0) { }
|
||||
|
||||
/// Reinterpret cast from CUDA's half type
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -83,11 +83,10 @@ struct integer_subbyte {
|
||||
integer_subbyte(unsigned value)
|
||||
: storage(reinterpret_cast<Storage const &>(value) & kMask) {}
|
||||
|
||||
/// Conversion from double
|
||||
CUTLASS_HOST_DEVICE
|
||||
integer_subbyte(double value) {
|
||||
T tmp = (T)value;
|
||||
storage = reinterpret_cast<Storage const &>(tmp) & kMask;
|
||||
T tmp = static_cast<T>(value);
|
||||
storage = Storage(reinterpret_cast<unsigned const &>(tmp) & kMask);
|
||||
}
|
||||
|
||||
///
|
||||
@ -155,6 +154,12 @@ struct integer_subbyte {
|
||||
/// 1-bit Unsigned integer type
|
||||
using uint1b_t = integer_subbyte<1, false>;
|
||||
|
||||
/// 2-bit Integer type
|
||||
using int2b_t = integer_subbyte<2, true>;
|
||||
|
||||
/// 2-bit Unsigned integer type
|
||||
using uint2b_t = integer_subbyte<2, false>;
|
||||
|
||||
/// 4-bit Integer type
|
||||
using int4b_t = integer_subbyte<4, true>;
|
||||
|
||||
@ -169,6 +174,18 @@ struct sizeof_bits<uint1b_t> {
|
||||
static int const value = 1;
|
||||
};
|
||||
|
||||
/// Defines the size of an element in bits - specialized for int2b_t
|
||||
template <>
|
||||
struct sizeof_bits<int2b_t> {
|
||||
static int const value = 2;
|
||||
};
|
||||
|
||||
/// Defines the size of an element in bits - specialized for uint2b_t
|
||||
template <>
|
||||
struct sizeof_bits<uint2b_t> {
|
||||
static int const value = 2;
|
||||
};
|
||||
|
||||
/// Defines the size of an element in bits - specialized for int4b_t
|
||||
template <>
|
||||
struct sizeof_bits<int4b_t> {
|
||||
|
||||
@ -35,7 +35,6 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace layout {
|
||||
@ -803,7 +802,7 @@ private:
|
||||
// Data members
|
||||
//
|
||||
|
||||
MatrixLayout layout_id_;
|
||||
Matrix layout_id_;
|
||||
|
||||
/// Stride data member
|
||||
Stride stride_;
|
||||
@ -815,12 +814,12 @@ public:
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
GeneralMatrix(): layout_id_(MatrixLayout::kColumnMajor), stride_(make_Coord(0, 1)) { }
|
||||
GeneralMatrix(): layout_id_(Matrix::kColumnMajor), stride_(make_Coord(0, 1)) { }
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
GeneralMatrix(
|
||||
MatrixLayout layout_id,
|
||||
Matrix layout_id,
|
||||
Index ldm,
|
||||
Index interleave): layout_id_(layout_id), stride_(make_Coord(ldm, interleave)) { }
|
||||
|
||||
@ -828,11 +827,11 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
static GeneralMatrix packed(
|
||||
MatrixCoord const &extent,
|
||||
MatrixLayout layout_id = MatrixLayout::kColumnMajor,
|
||||
Matrix layout_id = Matrix::kColumnMajor,
|
||||
Index interleave = 1) {
|
||||
|
||||
Index c;
|
||||
if (layout_id == MatrixLayout::kRowMajor) {
|
||||
if (layout_id == Matrix::kRowMajor) {
|
||||
c = extent.column();
|
||||
}
|
||||
else {
|
||||
@ -849,7 +848,7 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(MatrixCoord const &coord) const {
|
||||
Index c, s;
|
||||
if (layout_id_ == MatrixLayout::kRowMajor) {
|
||||
if (layout_id_ == Matrix::kRowMajor) {
|
||||
c = coord.column();
|
||||
s = coord.row();
|
||||
}
|
||||
@ -871,7 +870,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixLayout layout_id() const {
|
||||
Matrix layout_id() const {
|
||||
return layout_id_;
|
||||
}
|
||||
|
||||
@ -882,7 +881,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixLayout & layout_id() {
|
||||
Matrix & layout_id() {
|
||||
return layout_id_;
|
||||
}
|
||||
|
||||
@ -902,7 +901,7 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex capacity(MatrixCoord const &extent) const {
|
||||
Index s;
|
||||
if (layout_id_ == MatrixLayout::kRowMajor) {
|
||||
if (layout_id_ == Matrix::kRowMajor) {
|
||||
s = extent.row();
|
||||
}
|
||||
else {
|
||||
|
||||
@ -79,7 +79,7 @@ private:
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride data member - [c, wc, hwc]
|
||||
/// Stride data member - [stride_w, stride_h, stride_n]
|
||||
Stride stride_;
|
||||
|
||||
public:
|
||||
@ -93,7 +93,12 @@ public:
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNHWC(typename Stride::Index c, typename Stride::Index wc, typename Stride::Index hwc): stride_(make_Coord(c, wc, hwc)) { }
|
||||
TensorNHWC(
|
||||
typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates
|
||||
typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates
|
||||
typename Stride::Index stride_n ///< number of elements between adjacent N coordinates
|
||||
):
|
||||
stride_(make_Coord(stride_w, stride_h, stride_n)) { }
|
||||
|
||||
/// Helper returns a layout to a tightly packed NHWC tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -116,12 +121,6 @@ public:
|
||||
LongIndex(stride_[2] * coord.n());
|
||||
}
|
||||
|
||||
/// Returns a RowMajor equivalent for a TensorNHWC layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit operator RowMajor() {
|
||||
return RowMajor(stride_[0]);
|
||||
}
|
||||
|
||||
/// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord inverse(LongIndex index) const {
|
||||
@ -444,6 +443,107 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Mapping function for 5-D NDHWC tensors.
|
||||
class TensorNDHWC {
|
||||
public:
|
||||
/// Logical rank of tensor
|
||||
static int const kRank = 5;
|
||||
|
||||
/// Rank of stride vector
|
||||
static int const kStrideRank = 4;
|
||||
|
||||
/// Index type used for coordinates
|
||||
using Index = int32_t;
|
||||
|
||||
/// Long index type used for offsets
|
||||
using LongIndex = int64_t;
|
||||
|
||||
/// Logical coordinate (n, d, h, w, c)
|
||||
using TensorCoord = Tensor5DCoord;
|
||||
|
||||
/// Stride vector
|
||||
using Stride = Coord<kStrideRank>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride data member - [c, wc, hwc, dhwc]
|
||||
Stride stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNDHWC(Stride const &stride = Stride(0)): stride_(stride) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNDHWC(
|
||||
typename Stride::Index c,
|
||||
typename Stride::Index wc,
|
||||
typename Stride::Index hwc,
|
||||
typename Stride::Index dhwc):
|
||||
stride_(make_Coord(c, wc, hwc, dhwc)) { }
|
||||
|
||||
/// Helper returns a layout to a tightly packed NDHWC tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static TensorNDHWC packed(TensorCoord const &extent) {
|
||||
return TensorNDHWC(
|
||||
make_Coord(
|
||||
extent.c(),
|
||||
extent.w() * extent.c(),
|
||||
extent.h() * extent.w() * extent.c(),
|
||||
extent.d() * extent.h() * extent.w() * extent.c()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns the offset of a coordinate (n, d, h, w, c) in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(TensorCoord const &coord) const {
|
||||
return coord.c() +
|
||||
LongIndex(stride_[0] * coord.w()) +
|
||||
LongIndex(stride_[1] * coord.h()) +
|
||||
LongIndex(stride_[2] * coord.d()) +
|
||||
LongIndex(stride_[3] * coord.n());
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride & stride() {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Compute the number of contiguous elements needed to store a tensor with the given size
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex capacity(TensorCoord const &extent) const {
|
||||
// It does not make sense for the extent to be larger than the stride,
// and we cannot rely on the capacity calculation in such cases.
// These checks could be moved to debug-only code.
|
||||
if ((extent.c() > stride_[0])
|
||||
|| (extent.w() * stride_[0] > stride_[1])
|
||||
|| (extent.h() * stride_[1] > stride_[2])
|
||||
|| (extent.d() * stride_[2] > stride_[3])) {
|
||||
assert(0);
|
||||
}
|
||||
return extent.n() * stride_[3];
|
||||
}
|
||||
};
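/// Usage sketch (illustrative, not part of the original change): computes the
/// linear offset of a logical (n, d, h, w, c) coordinate in a tightly packed
/// NDHWC tensor. The function name is hypothetical.
CUTLASS_HOST_DEVICE
TensorNDHWC::LongIndex example_packed_ndhwc_offset(
  Tensor5DCoord const &extent,    ///< tensor extent (n, d, h, w, c)
  Tensor5DCoord const &coord) {   ///< logical coordinate to locate

  TensorNDHWC layout = TensorNDHWC::packed(extent);  // stride = [c, wc, hwc, dhwc]
  return layout(coord);
}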
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layout
|
||||
|
||||
@ -81,17 +81,23 @@ struct TensorOpMultiplicand {
|
||||
static int const kFactor =
|
||||
kTileShapeContiguous * kElementsPerAccess / kCrosswise;
|
||||
|
||||
/// The strided dimension needs to be at least WarpSize(32) /
|
||||
/// kTileShapeContiguous for a warp to access. To ensure conflict free
|
||||
static_assert(
|
||||
(kFactor > 0),
|
||||
"kCrosswise should be no large than one shared memory cache line.");
|
||||
|
||||
/// The strided dimension needs to be at least (WarpSize(32) /
|
||||
/// kTileShapeContiguous) for a warp to access. To ensure conflict free
|
||||
/// access, it also needs to be at least (kTileShapeContiguous / kFactor).
|
||||
/// See comments below
|
||||
static int const kTileShapeStride =
|
||||
((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous))
|
||||
? (kTileShapeContiguous / kFactor)
|
||||
: (32 / kTileShapeContiguous);
|
||||
|
||||
/// Fundamental tile shape in units of vectors
|
||||
/// For TN kblock=32 and 8x8x16 shapes, TileShape = <8, 4>.
|
||||
/// For the rest, TileShape = <8, 8>
|
||||
/// Fundamental tile shape in units of vectors to guarantee bank conflict free
|
||||
/// shared memory load/store.
|
||||
/// For kFactor = 1, TileShape = <8, 8>
|
||||
/// For kFactor > 1, TileShape = <8, 4>
|
||||
using TileShape = PitchLinearShape<kTileShapeContiguous, kTileShapeStride>;
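// Worked example (illustrative assumption: kTileShapeContiguous = 8, i.e. one
// 128-byte cache line of 128-bit accesses, with 16-bit elements so
// kElementsPerAccess = 8):
//   kCrosswise = 64:  kFactor = 8 * 8 / 64 = 1,
//                     kTileShapeStride = max(8 / 1, 32 / 8) = 8  ->  TileShape = <8, 8>
//   kCrosswise = 32:  kFactor = 8 * 8 / 32 = 2,
//                     kTileShapeStride = max(8 / 2, 32 / 8) = 4  ->  TileShape = <8, 4>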
|
||||
|
||||
/// Fundamental partition shape in units of vectors
|
||||
|
||||
14111
include/cutlass/matrix.h
Normal file
File diff suppressed because it is too large
@ -515,17 +515,30 @@ struct NumericConverterClamp<T, float> {
|
||||
using source_type = float;
|
||||
|
||||
static_assert((platform::is_same<result_type, int32_t>::value ||
|
||||
platform::is_same<result_type, int16_t>::value ||
|
||||
platform::is_same<result_type, uint16_t>::value ||
|
||||
platform::is_same<result_type, int8_t>::value ||
|
||||
platform::is_same<result_type, cutlass::int4b_t>::value),
|
||||
platform::is_same<result_type, uint8_t>::value ||
|
||||
platform::is_same<result_type, cutlass::int4b_t>::value ||
|
||||
platform::is_same<result_type, cutlass::uint4b_t>::value),
|
||||
"Clamp is only needed for integer types");
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & s) {
|
||||
|
||||
NumericConverter<result_type, double> convert_op;
|
||||
double kClamp_max, kClamp_min;
|
||||
|
||||
double kClamp_max = double((1U << (sizeof_bits<result_type>::value - 1)) - 1);
|
||||
double kClamp_min = -kClamp_max - 1;
|
||||
if (platform::is_same<result_type, int32_t>::value ||
|
||||
platform::is_same<result_type, int16_t>::value ||
|
||||
platform::is_same<result_type, int8_t>::value ||
|
||||
platform::is_same<result_type, cutlass::int4b_t>::value) {
|
||||
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value - 1)) - 1);
|
||||
kClamp_min = -kClamp_max - 1;
|
||||
} else {
|
||||
kClamp_max = double((1LLU << (sizeof_bits<result_type>::value)) - 1);
|
||||
kClamp_min = 0;
|
||||
}
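// With the expressions above, the clamp range tracks the destination type;
// for example: int8_t -> [-128, 127], uint8_t -> [0, 255], int4b_t -> [-8, 7],
// uint4b_t -> [0, 15], int16_t -> [-32768, 32767].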
|
||||
|
||||
double source = s;
|
||||
|
||||
@ -946,6 +959,130 @@ struct NumericArrayConverter<int8_t, int, N, Round> {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<uint8_t, 1> <= Array<int, 1>
|
||||
template <
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<uint8_t, int, 1, Round> {
|
||||
|
||||
using result_type = Array<uint8_t, 1>;
|
||||
using source_type = Array<int, 1>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
NumericConverter<uint8_t, int, Round> convert_element_;
|
||||
|
||||
result_type result;
|
||||
|
||||
result[0] = convert_element_(source[0]);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<uint8_t, 2> <= Array<int, 2>
|
||||
template <
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<uint8_t, int, 2, Round> {
|
||||
|
||||
using result_type = Array<uint8_t, 2>;
|
||||
using source_type = Array<int, 2>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
uint32_t tmp;
|
||||
|
||||
asm volatile(
|
||||
"cvt.pack.sat.u8.s32.b32 %0, %2, %1, 0;\n"
|
||||
: "=r"(tmp) : "r"(source[0]), "r"(source[1]));
|
||||
|
||||
uint16_t out = (tmp & 0xffff);
|
||||
return reinterpret_cast<result_type const &>(out);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<uint8_t, 4> <= Array<int, 4>
|
||||
template <
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<uint8_t, int, 4, Round> {
|
||||
|
||||
using result_type = Array<uint8_t, 4>;
|
||||
using source_type = Array<int, 4>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
unsigned out;
|
||||
|
||||
asm volatile(
|
||||
"{ .reg .u32 r4;"
|
||||
"cvt.pack.sat.u8.s32.b32 r4, %4, %3, 0;"
|
||||
"cvt.pack.sat.u8.s32.b32 %0, %2, %1, r4;"
|
||||
"}"
|
||||
: "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]));
|
||||
|
||||
return reinterpret_cast<result_type const &>(out);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<uint8_t> <= Array<int>
|
||||
template <
|
||||
int N,
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<uint8_t, int, N, Round> {
|
||||
static_assert(!(N % 4), "N must be multiple of 4.");
|
||||
|
||||
using result_type = Array<uint8_t, N>;
|
||||
using source_type = Array<int, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
NumericArrayConverter<uint8_t, int, 4, Round> convert_vector_;
|
||||
|
||||
result_type result;
|
||||
|
||||
Array<uint8_t, 4> *result_ptr = reinterpret_cast<Array<uint8_t, 4> *>(&result);
|
||||
Array<int, 4> const *source_ptr = reinterpret_cast<Array<int, 4> const *>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 4; ++i) {
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return convert(s);
|
||||
}
|
||||
};
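/// Usage sketch (illustrative, not part of the original change): packs four
/// 32-bit integers into four saturated uint8_t values using the specialization
/// above. The function name is hypothetical.
CUTLASS_DEVICE
Array<uint8_t, 4> example_pack_to_uint8(Array<int, 4> const &accum) {
  NumericArrayConverter<uint8_t, int, 4> convert_op;
  return convert_op(accum);   // each lane is clamped to [0, 255]
}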
|
||||
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1025,12 +1162,84 @@ struct NumericArrayConverter<int4b_t, int, N, Round> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<uint4b_t, 8> <= Array<int, 8>
|
||||
template <
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<uint4b_t, int, 8, Round> {
|
||||
|
||||
using result_type = Array<uint4b_t, 8>;
|
||||
using source_type = Array<int, 8>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
unsigned out;
|
||||
|
||||
asm volatile(
|
||||
"{ .reg .u32 r4;"
|
||||
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
|
||||
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
|
||||
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
|
||||
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
|
||||
"}"
|
||||
: "=r"(out)
|
||||
: "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]),
|
||||
"r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7]));
|
||||
|
||||
return reinterpret_cast<result_type const &>(out);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<uint4b_t> <= Array<int>
|
||||
template <
|
||||
int N,
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<uint4b_t, int, N, Round> {
|
||||
static_assert(!(N % 8), "N must be multiple of 8.");
|
||||
|
||||
using result_type = Array<uint4b_t, N>;
|
||||
using source_type = Array<int, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
NumericArrayConverter<uint4b_t, int, 8, Round> convert_vector_;
|
||||
|
||||
result_type result;
|
||||
|
||||
Array<uint4b_t, 8> *result_ptr = reinterpret_cast<Array<uint4b_t, 8> *>(&result);
|
||||
Array<int, 8> const *source_ptr = reinterpret_cast<Array<int, 8> const *>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 8; ++i) {
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // Conditional guards to enable partial specialization for packed integers
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// FastNumericArrayConverter only works when the source is within center range.
|
||||
/// Conversion operator for Array
|
||||
/// Conversion operator for Array. See the comments before
|
||||
/// FastLinearCombinationClamp.
|
||||
template <typename T, typename S, int N,
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
|
||||
struct FastNumericArrayConverter {
|
||||
|
||||
616
include/cutlass/quaternion.h
Normal file
@ -0,0 +1,616 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a densely packed quaternion object intended for storing data in registers and
|
||||
executing quaternion operations within a CUDA or host thread.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Quaternion: xi + yj + zk + w
|
||||
template <
|
||||
typename Element_ = float ///< element type
|
||||
>
|
||||
class Quaternion : public Array<Element_, 4> {
|
||||
public:
|
||||
|
||||
/// Logical rank of tensor index space
|
||||
static int const kRank = 1;
|
||||
|
||||
/// Number of elements
|
||||
static int const kExtent = 4;
|
||||
|
||||
/// Base class is a four-element array
|
||||
using Base = Array<Element_, kExtent>;
|
||||
|
||||
/// Element type
|
||||
using Element = typename Base::Element;
|
||||
|
||||
/// Reference type to an element
|
||||
using Reference = typename Base::reference;
|
||||
|
||||
/// Index type
|
||||
using Index = int;
|
||||
|
||||
/// Quaternion storage - imaginary part
|
||||
static int const kX = 0;
|
||||
|
||||
/// Quaternion storage - imaginary part
|
||||
static int const kY = 1;
|
||||
|
||||
/// Quaternion storage - imaginary part
|
||||
static int const kZ = 2;
|
||||
|
||||
/// Quaternion storage - real part
|
||||
static int const kW = 3;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion(
|
||||
Element w_ = Element(1)
|
||||
) {
|
||||
Base::at(kX) = Element(0);
|
||||
Base::at(kY) = Element(0);
|
||||
Base::at(kZ) = Element(0);
|
||||
Base::at(kW) = w_;
|
||||
}
|
||||
|
||||
/// Constructs a quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion(
|
||||
Element x_,
|
||||
Element y_,
|
||||
Element z_,
|
||||
Element w_
|
||||
) {
|
||||
Base::at(kX) = x_;
|
||||
Base::at(kY) = y_;
|
||||
Base::at(kZ) = z_;
|
||||
Base::at(kW) = w_;
|
||||
}
|
||||
|
||||
/// Constructs a quaternion from a vector representing the imaginary part and a real number
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion(
|
||||
Matrix3x1<Element> const &imag_,
|
||||
Element w_ = Element()
|
||||
) {
|
||||
Base::at(kX) = imag_[0];
|
||||
Base::at(kY) = imag_[1];
|
||||
Base::at(kZ) = imag_[2];
|
||||
Base::at(kW) = w_;
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Reference at(Index idx) const {
|
||||
return Base::at(idx);
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Reference at(Index idx) {
|
||||
return Base::at(idx);
|
||||
}
|
||||
|
||||
/// Accesses the x element of the imaginary part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element x() const {
|
||||
return Base::at(kX);
|
||||
}
|
||||
|
||||
/// Accesses the x element of the imaginary part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Reference x() {
|
||||
return Base::at(kX);
|
||||
}
|
||||
|
||||
/// Accesses the y element of the imaginary part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element y() const {
|
||||
return Base::at(kY);
|
||||
}
|
||||
|
||||
/// Accesses the y element of the imaginary part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Reference y() {
|
||||
return Base::at(kY);
|
||||
}
|
||||
|
||||
/// Accesses the z element of the imaginary part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element z() const {
|
||||
return Base::at(kZ);
|
||||
}
|
||||
|
||||
/// Accesses the z element of the imaginary part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Reference z() {
|
||||
return Base::at(kZ);
|
||||
}
|
||||
|
||||
/// Accesses the real part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element w() const {
|
||||
return Base::at(kW);
|
||||
}
|
||||
|
||||
/// Accesses the real part of the quaternion
|
||||
CUTLASS_HOST_DEVICE
|
||||
Reference w() {
|
||||
return Base::at(kW);
|
||||
}
|
||||
|
||||
/// Returns the pure imaginary part of the quaternion as a 3-vector
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> pure() const {
|
||||
return Matrix3x1<Element>(x(), y(), z());
|
||||
}
|
||||
|
||||
/// Returns a quaternion representation of a spatial rotation given a unit-length axis and
|
||||
/// a rotation in radians.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Quaternion<Element> rotation(
|
||||
Matrix3x1<Element> const &axis_unit, ///< axis of rotation (assumed to be unit length)
|
||||
Element theta) { ///< angular rotation in radians
|
||||
|
||||
Element s = fast_sin(theta / Element(2));
|
||||
|
||||
return Quaternion(
|
||||
s * axis_unit[0],
|
||||
s * axis_unit[1],
|
||||
s * axis_unit[2],
|
||||
fast_cos(theta / Element(2))
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns a quaternion representation of a spatial rotation represented as a
|
||||
/// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Quaternion<Element> rotation(
|
||||
Element r_x,
|
||||
Element r_y,
|
||||
Element r_z,
|
||||
Element theta) { ///< angular rotation in radians
|
||||
|
||||
return rotation({r_x, r_y, r_z}, theta);
|
||||
}
|
||||
|
||||
/// Geometric rotation of a 3-element vector
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> rotate(Matrix3x1<Element> const &rhs) const {
|
||||
return (*this * Quaternion<Element>(rhs, 0) * reciprocal(*this)).pure();
|
||||
}
|
||||
|
||||
/// Inverse rotation operation
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> rotate_inv(Matrix3x1<Element> const &rhs) const {
|
||||
return (reciprocal(*this) * Quaternion<Element>(rhs, 0) * *this).pure();
|
||||
}
|
||||
|
||||
/// Rotates a 3-vector assuming this is a unit quaternion (a spinor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> spinor(Matrix3x1<Element> const &rhs) const {
|
||||
return (*this * Quaternion<Element>(rhs, 0) * conj(*this)).pure();
|
||||
}
|
||||
|
||||
/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> spinor_inv(Matrix3x1<Element> const &rhs) const {
|
||||
return (conj(*this) * Quaternion<Element>(rhs, 0) * *this).pure();
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> &operator+=(Quaternion<Element> const &rhs) {
|
||||
*this = (*this + rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> &operator-=(Quaternion<Element> const &rhs) {
|
||||
*this = (*this - rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> &operator*=(Quaternion<Element> const &rhs) {
|
||||
*this = (*this * rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Scalar multiplication
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> &operator*=(Element s) {
|
||||
*this = (*this * s);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place Division
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> &operator/=(Quaternion<Element> const &rhs) {
|
||||
*this = (*this / rhs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place Division
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> &operator/=(Element s) {
|
||||
*this = (*this / s);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Computes a 3x3 rotation matrix (row-major representation)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x3<Element> as_rotation_matrix_3x3() const {
|
||||
Matrix3x3<Element> m(
|
||||
w() * w() + x() * x() - y() * y() - z() * z(),
|
||||
2 * x() * y() - 2 * w() * z(),
|
||||
2 * x() * z() + 2 * w() * y(),
|
||||
|
||||
2 * x() * y() + 2 * w() * z(),
|
||||
w() * w() - x() * x() + y() * y() - z() * z(),
|
||||
2 * y() * z() - 2 * w() * x(),
|
||||
|
||||
2 * x() * z() - 2 * w() * y(),
|
||||
2 * y() * z() + 2 * w() * x(),
|
||||
w() * w() - x() * x() - y() * y() + z() * z()
|
||||
);
|
||||
return m;
|
||||
}
|
||||
|
||||
/// Computes a 4x4 rotation matrix (row-major representation)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix4x4<Element> as_rotation_matrix_4x4() const {
|
||||
Matrix4x4<Element> m = Matrix4x4<Element>::identity();
|
||||
m.set_slice_3x3(as_rotation_matrix_3x3());
|
||||
return m;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constructs a quaternion that is non-zero only in its real element.
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> make_Quaternion(
|
||||
Element w) { ///< real part
|
||||
|
||||
return Quaternion<Element>(w);
|
||||
}
|
||||
|
||||
/// Constructs a quaternion from a vector and real
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> make_Quaternion(
|
||||
Matrix3x1<Element> const &imag, ///< imaginary part as a vector
|
||||
Element w) { ///< real part
|
||||
|
||||
return Quaternion<Element>(imag, w);
|
||||
}
|
||||
|
||||
/// Constructs a quaternion from a unit-length rotation axis and a rotation
|
||||
/// angle in radians
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> make_QuaternionRotation(
|
||||
Matrix3x1<Element> const &axis_unit, ///< rotation axis (unit-length)
|
||||
Element w) { ///< rotation angle in radians
|
||||
|
||||
return Quaternion<Element>::rotation(axis_unit, w);
|
||||
}
|
||||
|
||||
/// Constructs a quaternion q = xi + yj + zk + w
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> make_Quaternion(Element x, Element y, Element z, Element w) {
|
||||
return Quaternion<Element>(x, y, z, w);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Returns the magnitude of the quaternion
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element abs(Quaternion<Element> const &q) {
|
||||
return fast_sqrt(norm(q));
|
||||
}
|
||||
|
||||
/// Quaternion conjugate
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> conj(Quaternion<Element> const &q) {
|
||||
return make_Quaternion(
|
||||
-q.x(),
|
||||
-q.y(),
|
||||
-q.z(),
|
||||
q.w()
|
||||
);
|
||||
}
|
||||
|
||||
/// Computes the squared magnitude of the quaternion
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element norm(Quaternion<Element> const &q) {
|
||||
return q.x() * q.x() + q.y() * q.y() + q.z() * q.z() + q.w() * q.w();
|
||||
}
|
||||
|
||||
/// Quaternion reciprocal
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> reciprocal(Quaternion<Element> const &q) {
|
||||
|
||||
Element nsq = norm(q);
|
||||
|
||||
return make_Quaternion(
|
||||
-q.x() / nsq,
|
||||
-q.y() / nsq,
|
||||
-q.z() / nsq,
|
||||
q.w() / nsq
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns a unit-length quaternion
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> unit(Quaternion<Element> const &q) {
|
||||
|
||||
Element rcp_mag = Element(1) / abs(q);
|
||||
|
||||
return make_Quaternion(
|
||||
q.x() * rcp_mag,
|
||||
q.y() * rcp_mag,
|
||||
q.z() * rcp_mag,
|
||||
q.w() * rcp_mag
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion exponential
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> exp(Quaternion<Element> const &q) {
|
||||
|
||||
Element exp_ = fast_exp(q.w());
|
||||
Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z());
|
||||
Element sin_norm = fast_sin(imag_norm);
|
||||
|
||||
return make_Quaternion(
|
||||
exp_ * q.x() * sin_norm / imag_norm,
|
||||
exp_ * q.y() * sin_norm / imag_norm,
|
||||
exp_ * q.z() * sin_norm / imag_norm,
|
||||
exp_ * fast_cos(imag_norm)
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion natural logarithm
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> log(Quaternion<Element> const &q) {
|
||||
|
||||
Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z());
|
||||
Element s = fast_acos(q.w() / abs(q)) / v;
|
||||
|
||||
return make_Quaternion(
|
||||
q.x() * s,
|
||||
q.y() * s,
|
||||
q.z() * s,
|
||||
fast_log(q.w())
|
||||
);
|
||||
}
|
||||
|
||||
/// Gets the rotation angle from a unit-length quaternion
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Element get_rotation_angle(Quaternion<Element> const &q_unit) {
|
||||
return fast_acos(q_unit.w()) * Element(2);
|
||||
}
|
||||
|
||||
/// Gets the rotation axis from a unit-length quaternion
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> get_rotation_axis(Quaternion<Element> const &q_unit) {
|
||||
return q_unit.pure().unit();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Equality operator
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
|
||||
return lhs.x() == rhs.x() &&
|
||||
lhs.y() == rhs.y() &&
|
||||
lhs.z() == rhs.z() &&
|
||||
lhs.w() == rhs.w();
|
||||
}
|
||||
|
||||
/// Inequality operator
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
/// Quaternion scalar multiplication
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator*(Quaternion<Element> q, Element s) {
|
||||
return make_Quaternion(
|
||||
q.x() * s,
|
||||
q.y() * s,
|
||||
q.z() * s,
|
||||
q.w() * s
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion scalar multiplication
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator*(Element s, Quaternion<Element> const &q) {
|
||||
return make_Quaternion(
|
||||
s * q.x(),
|
||||
s * q.y(),
|
||||
s * q.z(),
|
||||
s * q.w()
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion scalar division
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator/(Quaternion<Element> const &q, Element s) {
|
||||
return make_Quaternion(
|
||||
q.x() / s,
|
||||
q.y() / s,
|
||||
q.z() / s,
|
||||
q.w() / s
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion unary negation
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator-(Quaternion<Element> const &q) {
|
||||
return make_Quaternion(
|
||||
-q.x(),
|
||||
-q.y(),
|
||||
-q.z(),
|
||||
-q.w()
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion addition
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator+(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
|
||||
return make_Quaternion(
|
||||
lhs.x() + rhs.x(),
|
||||
lhs.y() + rhs.y(),
|
||||
lhs.z() + rhs.z(),
|
||||
lhs.w() + rhs.w()
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion subtraction
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator-(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
|
||||
return make_Quaternion(
|
||||
lhs.x() - rhs.x(),
|
||||
lhs.y() - rhs.y(),
|
||||
lhs.z() - rhs.z(),
|
||||
lhs.w() - rhs.w()
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion product
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator*(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
|
||||
return make_Quaternion(
|
||||
lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y(),
|
||||
lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z(),
|
||||
lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x(),
|
||||
lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z()
|
||||
);
|
||||
}
|
||||
|
||||
/// Quaternion division
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator/(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
|
||||
return lhs * reciprocal(rhs);
|
||||
}
|
||||
|
||||
/// Quaternion scalar division
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Quaternion<Element> operator/(Element s, Quaternion<Element> const &q) {
|
||||
return s * reciprocal(q);
|
||||
}
|
||||
|
||||
/// Rotates a 3-vector assuming this is a unit quaternion (a spinor). This avoids computing
|
||||
/// a reciprocal.
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> spinor_rotation(
|
||||
Quaternion<Element> const &spinor, /// unit-length quaternion
|
||||
Matrix3x1<Element> const &rhs) { /// arbitrary 3-vector
|
||||
|
||||
return (spinor * Quaternion<Element>(rhs, 0) * conj(spinor)).pure();
|
||||
}
|
||||
|
||||
/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor). This avoids computing
|
||||
/// a reciprocal.
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix3x1<Element> spinor_rotation_inv(
|
||||
Quaternion<Element> const &spinor, /// unit-length quaternion
|
||||
Matrix3x1<Element> const &rhs) { /// arbitrary 3-vector
|
||||
|
||||
return (conj(spinor) * Quaternion<Element>(rhs, 0) * spinor).pure();
|
||||
}
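/// Usage sketch (illustrative, not part of the original change): rotates a
/// 3-vector by 'theta' radians about the z-axis using a unit quaternion.
/// The function name is hypothetical.
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> example_rotate_about_z(Matrix3x1<Element> const &v, Element theta) {
  Quaternion<Element> q = Quaternion<Element>::rotation(Element(0), Element(0), Element(1), theta);
  return spinor_rotation(q, v);   // q is unit length, so the spinor form avoids a reciprocal
}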
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Output operators
|
||||
//
|
||||
|
||||
template <typename Element>
|
||||
std::ostream &operator<<(std::ostream &out, Quaternion<Element> const &q) {
|
||||
return out << q.w() << "+i" << q.x() << "+j" << q.y() << "+k" << q.z();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -22,6 +22,11 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/**
|
||||
\file
|
||||
\brief This class provides helpers to support real<> and complex<> types in generic code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -1,179 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient batched reduction.
|
||||
D = alpha * Reduction(A) + beta * C
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace reduction {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename batched_reduction_>
|
||||
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_::Params params) {
|
||||
// Construct the batched_reduction object
|
||||
batched_reduction_ batched_reduction(params);
|
||||
batched_reduction.run();
|
||||
}
|
||||
|
||||
template <typename BatchedReductionTraits_>
|
||||
struct BatchedReduction {
|
||||
/// This class
|
||||
typedef BatchedReduction<BatchedReductionTraits_> This_;
|
||||
/// The traits
|
||||
typedef BatchedReductionTraits_ Traits;
|
||||
/// Params
|
||||
typedef typename Traits::Params Params;
|
||||
/// functor
|
||||
typedef typename Traits::Functor Functor;
|
||||
|
||||
/// ctor
|
||||
CUTLASS_DEVICE BatchedReduction(Params const ¶ms_)
|
||||
: params(params_), functor(params_.functorParams) {}
|
||||
|
||||
/// main operation method
|
||||
/// D = alpha * Reduction(A) + beta * C
|
||||
CUTLASS_DEVICE void run() {
|
||||
#if (__CUDA_ARCH__ >= 600)
|
||||
// Swizzle the IDs of the block
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
Coord<3> threadblock_offset =
|
||||
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());
|
||||
|
||||
int subTileSize = gridDim.x * Traits::SubTile::kW;
|
||||
int tileSize = params.problem_size[1] * params.problem_size[2];
|
||||
int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW;
|
||||
|
||||
int subTileBase = 0;
|
||||
|
||||
typename Traits::ScalarA inRegs[Traits::maxInReg];
|
||||
typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
|
||||
#pragma unroll
|
||||
for (int subTile = 0; subTile < tileSize; subTile += subTileSize) {
|
||||
int tileOffset = subTileBase + subTileOffset;
|
||||
// Init AccumRegs
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++)
|
||||
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
|
||||
// Fetch c0
|
||||
typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
|
||||
#pragma unroll
|
||||
for (int i = 0; i< Traits::ThreadShape::kW; i++)
|
||||
c0[i] = static_cast<typename Traits::ScalarAccum>(params.d_c[tileOffset + i]);
|
||||
|
||||
// Fetch partial sums from A
|
||||
#pragma unroll
|
||||
for (int s = 0; s < Traits::ReductionSize; s++) {
|
||||
int inRegOffset = s * Traits::ThreadShape::kW;
|
||||
int dOffset = (s * tileSize) + tileOffset;
|
||||
#pragma unroll
|
||||
for (int i = 0; i< Traits::ThreadShape::kW; i++) {
|
||||
inRegs[inRegOffset + i] = params.d_a[dOffset + i];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate
|
||||
#pragma unroll
|
||||
for (int s = 0; s < Traits::ReductionSize; s++) {
|
||||
int inRegOffset = s * Traits::ThreadShape::kW;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
|
||||
//AccumRegs[i] = cuFma(params.alpha, inRegs[inRegOffset + i], AccumRegs[i]);
|
||||
//AccumRegs[i] = params.alpha * inRegs[inRegOffset + i] + AccumRegs[i];
|
||||
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(inRegs[inRegOffset + i]) + AccumRegs[i];
|
||||
}
|
||||
}
|
||||
// calling functor
|
||||
functor_caller<Traits::ThreadShapeMultiple2>(AccumRegs, c0, AccumRegs);
|
||||
|
||||
// Store AccumRegs to D
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
|
||||
params.d_d[tileOffset + i] = static_cast<typename Traits::ScalarD>(AccumRegs[i]);
|
||||
}
|
||||
|
||||
// Advance sub-tile pointer
|
||||
subTileBase += subTileSize;
|
||||
} // end for loop
|
||||
#endif //#if (__CUDA_ARCH__ >= 600)
|
||||
}
|
||||
|
||||
template<bool ThreadShapeMultiple2>
|
||||
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output) {
|
||||
if (ThreadShapeMultiple2 == true) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
|
||||
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
|
||||
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
  //
  // Static function members
  //
#if !defined(__CUDACC_RTC__)
  /// Launch the kernel.
  static __host__ cudaError_t launch(Params const& params,
                                     cudaStream_t stream = cudaStreamDefault) {
    // Setup the grid.
    typename Traits::BlockSwizzle block_swizzle;
    dim3 grid = block_swizzle.get_grid_layout(params.problem_size,
                                              make_Coord_from_shape<typename Traits::OutputTile>());

    dim3 block;
    block.x = Traits::kThreads;
    batched_reduction_kernel<This_><<<grid, block, 0, stream>>>(params);
    return cudaGetLastError();
  }
#endif

  //
  // Data members
  //

  /// The params.
  Params const& params;
  // The functor.
  Functor functor;
};

} // namespace reduction
} // namespace cutlass

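For reference, the device loop above computes, per output element, a sum of ReductionSize partial values followed by the epilogue functor. The host-side sketch below restates that math in plain C++. It is illustrative only: the function name reference_batched_reduction, the float element type, and the back-to-back layout of the partial slices (matching dOffset = s * tileSize + tileOffset above) are assumptions for this example, not CUTLASS code.

#include <vector>

// Reference: D[i] = alpha * sum_{s < reduction_size} A[s * elements + i] + beta * C[i]
void reference_batched_reduction(int elements, int reduction_size,
                                 float alpha, float beta,
                                 std::vector<float> const &A,
                                 std::vector<float> const &C,
                                 std::vector<float> &D) {
  for (int i = 0; i < elements; ++i) {
    float accum = 0.0f;                  // "Init AccumRegs"
    for (int s = 0; s < reduction_size; ++s) {
      accum += A[s * elements + i];      // "Fetch partial sums from A" + "Accumulate"
    }
    D[i] = alpha * accum + beta * C[i];  // epilogue functor (default: linear scaling)
  }
}
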
@ -1,192 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties of complete batched reduction.
           D = alpha * Reduction(A) + beta * C
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/shape.h"
#include "cutlass/reduction/threadblock_swizzle.h"
#include "cutlass/reduction/batched_reduction.h"
#include "cutlass/gemm/linear_scaling.h"

namespace cutlass {
namespace reduction {

/*
  OutputTile defines the work load per thread block.
  SubTile defines the work load per thread block per iteration:
    OutputTile / SubTile = number of iterations within a kernel.
  ThreadShape defines the work load per thread:
    SubTile / ThreadShape = number of threads per thread block.
*/
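// Worked example of the shape arithmetic above (added for illustration; not part of the original
// header), using the default shapes declared below:
//   OutputTile = Shape<1, 1, 128>, SubTile = Shape<1, 1, 64>, ThreadShape = Shape<1, 1, 2>
//   iterations per kernel   = OutputTile::kW / SubTile::kW   = 128 / 64 = 2
//   threads per threadblock = SubTile::kW / ThreadShape::kW  =  64 /  2 = 32  (one warp)
static_assert(Shape<1, 1, 128>::kW / Shape<1, 1, 64>::kW == 2,
              "default shapes: two sub-tile iterations per threadblock");
static_assert(Shape<1, 1, 64>::kW / Shape<1, 1, 2>::kW == 32,
              "default shapes: 32 threads per threadblock");
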
template <
    /// The scalar type for A
    typename ScalarA_,
    /// The scalar type for C
    typename ScalarC_,
    /// The scalar type for D
    typename ScalarD_,
    /// The scalar type for alpha and beta
    typename ScalarAlphaBeta_,
    /// The scalar type for the accumulator
    typename ScalarAccum_,
    /// Reduction work load per batch
    int ReductionSize_ = 1,
    /// The output tile; work load per thread block
    typename OutputTile_ = Shape<1, 1, 128>,
    /// The subtile
    typename SubTile_ = Shape<1, 1, 64>,
    /// Work load per thread, per subtile
    typename ThreadShape_ = Shape<1, 1, 2>,
    /// The index
    typename Index_ = int,
    /// The block swizzle to reorganize the grid.
    typename BlockSwizzle_ = DefaultBlockSwizzle,
    /// The input register vector size in the kernel
    int maxInReg_ = 160,
    /// The output register vector size in the kernel
    int maxOutReg_ = 64,
    /// The functor that will be executed at the end
    typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >
>
struct BatchedReductionTraits {
  ///
  typedef BatchedReductionTraits<ScalarA_,
                                 ScalarC_,
                                 ScalarD_,
                                 ScalarAlphaBeta_,
                                 ScalarAccum_,
                                 ReductionSize_,
                                 OutputTile_,
                                 SubTile_,
                                 ThreadShape_,
                                 Index_,
                                 BlockSwizzle_,
                                 maxInReg_,
                                 maxOutReg_,
                                 Functor_> This_;
  /// The struct that consumes this Traits
  typedef typename cutlass::reduction::BatchedReduction<This_> KernelClass;
  ///
  typedef OutputTile_ OutputTile;
  ///
  typedef SubTile_ SubTile;
  ///
  typedef ThreadShape_ ThreadShape;
  /// The input pointer type
  typedef ScalarA_ ScalarA;
  ///
  typedef ScalarC_ ScalarC;
  /// The output pointer type
  typedef ScalarD_ ScalarD;
  /// The alpha beta type
  typedef ScalarAlphaBeta_ ScalarAlphaBeta;
  /// The type for accumulation
  typedef ScalarAccum_ ScalarAccum;
  /// The index
  typedef Index_ Index;
  /// The thread block swizzle
  typedef BlockSwizzle_ BlockSwizzle;
  ///
  static const int ReductionSize = ReductionSize_;
  /// True if ThreadShape::kW is a multiple of 2
  static const bool ThreadShapeMultiple2 = (ThreadShape::kW % 2 == 0);
  ///
  typedef Functor_ Functor;
  /// Parameters object constructible on the host
  /// The number of threads per thread block; can be deduced
  static int const kThreads = SubTile::kW / ThreadShape::kW;
  //
  static int const maxInReg = maxInReg_;
  //
  static int const maxOutReg = maxOutReg_;
  //
  static_assert(SubTile::kW % ThreadShape::kW == 0, "cannot evenly distribute work load among threads");
  //
  static_assert(kThreads % 32 == 0, "threads per threadblock is not a multiple of 32");
  //
  static_assert(OutputTile::kW % SubTile::kW == 0, "cannot evenly distribute work load among iterations");
  //
  static_assert(ReductionSize * ThreadShape::kW <= maxInReg, "ReductionSize * ThreadShape::kW should not be bigger than maxInReg");
  //
  static_assert(ThreadShape::kW <= maxOutReg, "ThreadShape::kW should not be bigger than maxOutReg");

  struct Params {
    /// The dimensions of the output tensor
    Coord<3> problem_size;
    /// The alpha
    ScalarAlphaBeta alpha;
    /// The beta
    ScalarAlphaBeta beta;
    /// The stride between two elements that will be summed
    long long int reduction_stride;
    //
    ScalarA const *d_a;
    //
    Index lda;
    //
    ScalarC const *d_c;
    //
    Index ldc;
    //
    ScalarD *d_d;
    //
    Index ldd;
    /// The functor params.
    typename Functor::Params functorParams;
    /// Initialize the parameters for a 2D output tensor
    CUTLASS_HOST_DEVICE int initialize(Index m_,
                                       Index n_,
                                       ScalarAlphaBeta alpha_,
                                       ScalarAlphaBeta beta_,
                                       long long int reduction_stride_,
                                       ScalarA const *d_a_,
                                       Index lda_,
                                       ScalarC const *d_c_,
                                       Index ldc_,
                                       ScalarD *d_d_,
                                       Index ldd_) {
      problem_size = make_Coord(1, n_, m_);
      alpha = alpha_;
      beta = beta_;
      reduction_stride = reduction_stride_;
      d_a = d_a_;
      lda = lda_;
      d_c = d_c_;
      d_d = d_d_;
      ldc = ldc_;
      ldd = ldd_;

      functorParams.initialize(alpha_, beta_);

      return 0;
    }
  };

};
} // namespace reduction
} // namespace cutlass

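Putting the pieces together, a host-side caller might use these traits as in the sketch below. This is an illustrative example only: the wrapper name reduce_partial_sums, the all-float types, the hard-coded ReductionSize of 4, the header path, and the assumption that BatchedReduction<Traits> accepts Traits::Params in launch() (as the signatures above suggest) are not taken from the repository.

#include <cuda_runtime.h>
#include "cutlass/reduction/batched_reduction_traits.h"  // assumed path of the header above

// D = alpha * (A_0 + A_1 + A_2 + A_3) + beta * C for an m x n output with leading dimension ld.
cudaError_t reduce_partial_sums(int m, int n, int ld,
                                float alpha, float beta,
                                long long reduction_stride,
                                float const *d_a,   // four partial slices
                                float const *d_c,
                                float *d_d,
                                cudaStream_t stream = cudaStreamDefault) {
  typedef cutlass::reduction::BatchedReductionTraits<
      float, float, float, float, float, /*ReductionSize_=*/4> Traits;

  Traits::Params params;
  params.initialize(m, n, alpha, beta, reduction_stride,
                    d_a, ld, d_c, ld, d_d, ld);

  return Traits::KernelClass::launch(params, stream);
}

With the default shapes this launches 32 threads per block, and alpha/beta are forwarded to the linear-scaling functor through functorParams.initialize(alpha_, beta_) shown above.
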
@ -77,6 +77,18 @@ bool relatively_equal<uint1b_t>(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) {
  return (a == b);
}

template <>
CUTLASS_HOST_DEVICE
bool relatively_equal<int2b_t>(int2b_t a, int2b_t b, int2b_t, int2b_t) {
  return (a == b);
}

template <>
CUTLASS_HOST_DEVICE
bool relatively_equal<uint2b_t>(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) {
  return (a == b);
}

template <>
CUTLASS_HOST_DEVICE
bool relatively_equal<int4b_t>(int4b_t a, int4b_t b, int4b_t, int4b_t) {
  return (a == b);
}
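These specializations make relatively_equal an exact comparison for the sub-byte integer types, while floating-point types keep a tolerance-based comparison. A hypothetical usage sketch follows; the include paths and the tolerance values are assumptions for illustration, not taken from this diff.

#include "cutlass/numeric_types.h"
#include "cutlass/relatively_equal.h"   // assumed location of relatively_equal

bool compare_examples() {
  // Floating point: compared within a relative tolerance (epsilon), with a floor for tiny values.
  bool f_ok = cutlass::relatively_equal<float>(1.0001f, 1.0f, 1e-3f, 1e-30f);

  // Sub-byte integers: the specializations above ignore the last two arguments and compare exactly.
  bool i_ok = cutlass::relatively_equal<cutlass::int4b_t>(
      cutlass::int4b_t(3), cutlass::int4b_t(3), cutlass::int4b_t(0), cutlass::int4b_t(0));

  return f_ok && i_ok;
}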