CHANGELOG.md
@@ -2,6 +2,33 @@
 
 # CUTLASS 2.x
 
+## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
+ * Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
+   * Adopts the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h)
+ * Fused operators with GEMM and Convolution
+   * [Fused broadcast in epilogue](/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
+   * [Fused partial reduction in epilogue](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
+ * 64b tensor strides and leading dimensions support for GEMMs
+ * Affine rank=2 matrix layouts
+   * Row stride and column stride for matrices using [cutlass::layout::AffineRank2](/include/cutlass/layout/matrix.h)
+   * Supported in [FP64 tensor core](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM
+ * [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation
+ * [New strided Dgrad](/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
+   * Accelerates over the previous implementation by cutting down redundant math by 4x
+   * Supports the new `Dy` and `w` analytic iterators and the existing `cutlass::conv::device::ImplicitGemmConvolution` interface
+ * Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
+   * Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h)
+   * SDK examples for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu)
+   * [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
+ * Many improvements to the epilogue
+   * Provides an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue, which reduces code size and improves performance with complicated elementwise operations
+   * Performance improvement for FP16 tensor core kernels
+   * Bug fixes
+ * Updated minimum CUDA Toolkit requirement to 10.2
+   * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
+ * Corrections and bug fixes reported by the CUTLASS community
+   * Thank you for filing these issues!
+
 ## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
 * Tensor reductions
   * _m_-to-_n_ reductions of tensors with affine layout
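To make the affine rank-2 bullet concrete, here is a minimal sketch. The `AffineRank2ColumnMajor` name comes from `/include/cutlass/layout/matrix.h` as linked above, but the `packed()` factory and coordinate call shown are assumptions about that API, not code from this changeset:

```cpp
#include "cutlass/layout/matrix.h"

int main() {
  // An affine rank-2 layout keeps an independent row stride and column stride,
  // so neither dimension of the matrix needs to be contiguous in memory.
  cutlass::layout::AffineRank2ColumnMajor layout =
      cutlass::layout::AffineRank2ColumnMajor::packed({128, 64});

  // offset = row * row_stride + column * column_stride
  auto offset = layout({3, 5});
  return offset >= 0 ? 0 : 1;
}
```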
CMakeLists.txt
@@ -32,9 +32,15 @@ endif()
 
 message(STATUS "CMake Version: ${CMAKE_VERSION}")
 
-project(CUTLASS VERSION 2.5.0 LANGUAGES CXX)
+project(CUTLASS VERSION 2.6.0 LANGUAGES CXX)
 include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
 
+if (CUDA_VERSION VERSION_LESS 10.2)
+  message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.")
+elseif (CUDA_VERSION VERSION_LESS 11.0)
+  message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.")
+endif()
+
 find_package(Doxygen QUIET)
 
 #
@@ -105,7 +111,7 @@ endif()
 if (NOT CUDA_VERSION VERSION_LESS 11.0)
   list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
 endif()
-if (NOT CUDA_VERSION VERSION_LESS 11.1)
+if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
   list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
 endif()
 set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@@ -275,7 +281,14 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
     message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
   endif()
 
+  # There are numerous Clang versions that can work with each CUDA toolkit and
+  # the checks are not very useful, so we are turning them off and using testing to
+  # ensure the various combinations work properly.
+
   list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
+  list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__NV_NO_HOST_COMPILER_CHECK=1)
+  list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unknown-cuda-version)
 
   list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
   list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
   list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
@@ -294,18 +307,28 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
   link_libraries(nvidia::cudart)
 endif()
 
+if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
+  # CMake 3.18 added support for the CUDA_ARCHITECTURES target property. We will use this
+  # property for CMake 3.18+, so we request the NEW behavior for correct compatibility.
+  # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104
+  cmake_policy(SET CMP0104 NEW)
+endif()
+
 function(cutlass_apply_cuda_gencode_flags TARGET)
 
   set(NVCC_FLAGS)
+  set(CLANG_FLAGS)
+  set(__CMAKE_CUDA_ARCHS)
   foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
+    list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
     set(CODES)
     if(CUTLASS_NVCC_EMBED_CUBIN)
       list(APPEND CODES sm_${ARCH})
+      list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real)
     endif()
     if(CUTLASS_NVCC_EMBED_PTX)
       list(APPEND CODES compute_${ARCH})
+      list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual)
     endif()
     list(JOIN CODES "," CODES_STR)
     list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
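For reference, with both `CUTLASS_NVCC_EMBED_CUBIN` and `CUTLASS_NVCC_EMBED_PTX` enabled, the loop above assembles a flag such as `-gencode=arch=compute_80,code=[sm_80,compute_80]` for `ARCH=80`, embedding both the native cubin and forward-compatible PTX for that architecture.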
@@ -317,6 +340,8 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
       PRIVATE
       $<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
       )
+  elseif(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
+    set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS})
   else()
     target_compile_options(
       ${TARGET}
@@ -542,10 +567,14 @@ function(cutlass_add_executable_tests NAME TARGET)
   #
 
   set(options DISABLE_EXECUTABLE_INSTALL_RULE)
-  set(oneValueArgs)
+  set(oneValueArgs DISABLE_TESTS)
   set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
   cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
 
+  if (NOT DEFINED __DISABLE_TESTS)
+    set(__DISABLE_TESTS OFF)
+  endif()
+
   if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS)
 
   # file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR})
@@ -610,6 +639,8 @@ function(cutlass_add_executable_tests NAME TARGET)
     COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
     )
 
+  set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS})
+
   if (CUTLASS_INSTALL_TESTS)
 
     # To run the tests from an install package with tests enabled, we need to generate test files
README.md
@@ -1,8 +1,8 @@
 ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
 
-# CUTLASS 2.5
+# CUTLASS 2.6
 
-_CUTLASS 2.5 - February 2021_
+_CUTLASS 2.6 - July 2021_
 
 CUTLASS is a collection of CUDA C++ template abstractions for implementing
 high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -34,12 +34,24 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
 See the [functionality listing](/media/docs/functionality.md) for the list of operations
 supported at each level of the execution model hierarchy.
 
+# What's New in CUTLASS 2.6
+CUTLASS 2.6 is a minor update to CUTLASS adding:
+- Fused [broadcast](/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reduction](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution
+- [Quaternion-valued GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu) in single precision
+- [New strided Dgrad](/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation, offering up to 4x performance improvement over the previous strided Dgrad
+- 64-bit strides for large tensor allocations (sketched below)
+- [General affine layouts](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) in FP64 tensor core and SIMT GEMM
+- Enhanced functionality, boosted performance, and bug fixes in the epilogue
+- Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
+- Adoption of the new L2 prefetch feature in [PTX instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-isa-version-7-4)
+- Numerous updates from the community (thanks!)
+- See the [CHANGELOG](CHANGELOG.md) for more details
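A minimal sketch of what the 64-bit stride support looks like at the API surface, based on the `cutlass::layout::ColumnMajor::Stride::Index` type that replaces plain `int` leading dimensions later in this changeset; the assumption that this type is 64 bits wide follows from the feature description, and the matrix size is a placeholder:

```cpp
#include "cutlass/layout/matrix.h"

int main() {
  // Leading dimensions are now expressed in the layout's Stride::Index type,
  // wide enough for tensors whose strides overflow a 32-bit int.
  using Index = cutlass::layout::ColumnMajor::Stride::Index;

  Index lda = (Index(1) << 33);  // a stride that cannot be represented in an int
  return lda > 0 ? 0 : 1;
}
```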
 # What's New in CUTLASS 2.5
 CUTLASS 2.5 is a minor update to CUTLASS adding:
 - [Tensor reductions](/test/unit/reduction/device/tensor_reduce_contiguous.cu)
 - [Optimizations for 3-D convolution](/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h)
 - [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md)
 - See the [CHANGELOG](CHANGELOG.md) for more details
 
 # What's New in CUTLASS 2.4
 CUTLASS 2.4 is a significant update to CUTLASS adding:
@@ -52,7 +64,7 @@ CUTLASS 2.4 is a significant update to CUTLASS adding:
 CUTLASS 2.3 is a minor update to CUTLASS adding:
 - GEMMs targeting structured [Sparse Tensor Cores](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) in NVIDIA Ampere Architecture GPUs
 - Fast SGEMM kernels targeting GeForce RTX 30-series CUDA Cores
-- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit)
+- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later
 
 # What's New in CUTLASS 2.2
 
@@ -62,7 +74,7 @@ CUTLASS 2.2 is a significant update to CUTLASS adding:
 - Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
 - Deep software pipelines using asynchronous copy
 - Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745)
-- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)
+- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later
 
 # What's New in CUTLASS 2.1
 
@@ -95,8 +107,8 @@ using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
 # Compatibility
 
 CUTLASS requires a C++11 host compiler and
-performs best when built with the [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
-It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, CUDA 10.2, and CUDA 11.0.
+performs best when built with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit).
+It is also compatible with CUDA 10.2, CUDA 11.0, CUDA 11.1, CUDA 11.2, and CUDA 11.3.
 
 We have tested the following environments.
 
@@ -106,12 +118,16 @@ We have tested the following environments.
 |                 | Microsoft Visual Studio 2017 |
 | Ubuntu 16.04    | GCC 5.4.0  |
 | Ubuntu 18.04    | GCC 7.5.0  |
+| Ubuntu 20.04    | GCC 10.2.0 |
 
+Additionally, CUTLASS may be built with clang.
+See [these instructions](media/docs/quickstart.md#clang) for more details.
+
 CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
 any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
 
 For all GPUs, we recommend compiling with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
 for best performance.
 
 |**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
 |---|---|---|---|
@@ -511,6 +527,7 @@ CUTLASS is released by NVIDIA Corporation as Open Source software under the
 
 The official list of CUTLASS developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md).
 
 # Copyright
 
 Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
@@ -17,3 +17,5 @@ add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}"
 if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
   set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
 endif()
+
+set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
@@ -119,12 +119,12 @@ cudaError_t cutlass_hgemm_nn(
   int K,
   cutlass::half_t alpha,
   cutlass::half_t const *A,
-  int lda,
+  cutlass::layout::ColumnMajor::Stride::Index lda,
   cutlass::half_t const *B,
-  int ldb,
+  cutlass::layout::ColumnMajor::Stride::Index ldb,
   cutlass::half_t beta,
   cutlass::half_t *C,
-  int ldc) {
+  cutlass::layout::ColumnMajor::Stride::Index ldc) {
 
   // Define the GEMM operation
   using Gemm = cutlass::gemm::device::Gemm<
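For illustration, a call site for the updated `cutlass_hgemm_nn` signature might look like the following; only the signature comes from the diff above, while the pointers and extents are hypothetical:

```cpp
// M, N, K and device pointers d_A, d_B, d_C are assumed to be set up elsewhere.
cutlass::layout::ColumnMajor::Stride::Index lda = M;  // was `int` before this change
cutlass::layout::ColumnMajor::Stride::Index ldb = K;
cutlass::layout::ColumnMajor::Stride::Index ldc = M;

cudaError_t result = cutlass_hgemm_nn(
    M, N, K,
    cutlass::half_t(1.0f), d_A, lda,
    d_B, ldb,
    cutlass::half_t(0.0f), d_C, ldc);
```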
@@ -67,7 +67,7 @@ beta * C).
 Now that we have set up the properties of the data, we have to set up the properties of the computation.
 
 Second, we create template variables for the tile sizes of the thread-block, warp, and MMA op: 128x128x32,
-64x64x4, 8x8x4 (MxNxK) respectively. When passed to instantiate the CUTLASS GEMM kernel, it internally
+64x64x32, 8x8x4 (MxNxK) respectively. When passed to instantiate the CUTLASS GEMM kernel, it internally
 deduces the number of threads needed per thread-block, the amount of shared memory, how to store data in a
 bank-conflict-free manner, and a ton of other variables required to compose, initialize, and launch a
 high-performance GEMM kernel. This is the beauty of CUTLASS; it relieves the developer from
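The three tile sizes named above map directly onto CUTLASS shape types. A sketch (the typedef names are conventional, not quoted from this diff):

```cpp
#include "cutlass/gemm/gemm.h"

// Thread-block, warp, and MMA instruction tiles, in MxNxK order,
// matching the 128x128x32 / 64x64x32 / 8x8x4 decomposition above.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
```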
@@ -275,10 +275,10 @@ public:
     int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
     int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
 
-    int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
-    int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
-    int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
-    int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
+    typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
+    typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
+    typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
+    typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
 
     int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
     int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
 
@@ -292,10 +292,11 @@ public:
     int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
     int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
 
-    int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
-    int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
-    int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
-    int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
+    typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
+    typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
+    typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
+    typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
+
 
     int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
     int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
@@ -48,6 +48,10 @@ addition to its own input activation tile. Therefore the input activation warp tile
 2nd GEMM/Conv only depends on the output warp accumulator of the 1st GEMM/Conv in the
 register file, and the operation can be fully register-file-resident.
 
+When applying the above constraint to convolutions, the 2nd Convolution kernel must not
+have halos, so that the data used by each threadblock does not depend on any other
+threadblock. Typically this requires that the 2nd Convolution use a 1x1 filter without
+any padding. A sketch of this condition follows.
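A minimal sketch of the "no halo" condition described above, written against fields of `cutlass::conv::Conv2dProblemSize`; the predicate itself is illustrative and not an API from this changeset:

```cpp
#include "cutlass/conv/conv2d_problem_size.h"

// True when the second convolution cannot read activations owned by another
// threadblock: a 1x1 filter (R = S = 1) with zero padding has no halo.
bool second_conv_has_no_halo(cutlass::conv::Conv2dProblemSize const &problem) {
  return problem.R == 1 && problem.S == 1 &&
         problem.pad_h == 0 && problem.pad_w == 0;
}
```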
 # Copyright
 
 Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
@@ -36,8 +36,6 @@
 #include "device/b2b_implicit_gemm_convolution.h"
 #include "b2b_conv2d_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
@@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
   {128, 56, 56, 64}    // output size (NPQK)
 );
 
-void run_nonfused_conv2d_fprop_f16_sm75() {
+bool run_nonfused_conv2d_fprop_f16_sm75() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_f16_sm75() {
     ElementC,
     128 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   2,
@@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_f16_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_f16_sm75() {
+bool run_fused_conv2d_fprop_f16_sm75() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_f16_sm75() {
     ElementC,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -207,9 +208,10 @@ void run_fused_conv2d_fprop_f16_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_nonfused_conv2d_fprop_optimized_f16_sm75() {
+bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -242,7 +244,8 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm75() {
     ElementC,
     128 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   2,
@@ -287,9 +290,10 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_optimized_f16_sm75() {
+bool run_fused_conv2d_fprop_optimized_f16_sm75() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -313,7 +317,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm75() {
     ElementC,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -359,10 +364,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif // if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
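The recurring two-line change in these runners pins the first stage's epilogue to alpha-only scaling. As a sketch, the resulting functor type looks like this; `LinearCombinationRelu` is assumed to be the functor these examples instantiate, since the diff shows only its trailing template arguments:

```cpp
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"

using ElementC           = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute     = cutlass::half_t;

// OnlyAlphaScaling compiles out the beta * C term entirely: D = relu(alpha * accumulator),
// which is what the intermediate GEMM/Conv of a fused pair needs.
using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementC,
    128 / cutlass::sizeof_bits<ElementC>::value,
    ElementAccumulator,
    ElementCompute,
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
```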
@@ -36,8 +36,6 @@
 #include "device/b2b_implicit_gemm_convolution.h"
 #include "b2b_conv2d_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
@@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
   {128, 56, 56, 64}    // output size (NPQK)
 );
 
-void run_nonfused_conv2d_fprop_f16_sm80() {
+bool run_nonfused_conv2d_fprop_f16_sm80() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_f16_sm80() {
     ElementC,
     128 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   3,
@@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_f16_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_f16_sm80() {
+bool run_fused_conv2d_fprop_f16_sm80() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_f16_sm80() {
     ElementC,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -205,9 +206,10 @@ void run_fused_conv2d_fprop_f16_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_nonfused_conv2d_fprop_optimized_f16_sm80() {
+bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -240,7 +242,8 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm80() {
     ElementC,
     128 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   3,
@@ -285,9 +288,10 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_optimized_f16_sm80() {
+bool run_fused_conv2d_fprop_optimized_f16_sm80() {
 
   using ElementA = cutlass::half_t;
   using ElementB = cutlass::half_t;
@@ -311,7 +315,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm80() {
     ElementC,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -355,9 +360,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
@@ -36,8 +36,6 @@
 #include "device/b2b_implicit_gemm_convolution.h"
 #include "b2b_interleaved_conv2d_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
@@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
   {128, 56, 56, 64}    // output size (NPQK)
 );
 
-void run_nonfused_conv2d_fprop_s8_sm75() {
+bool run_nonfused_conv2d_fprop_s8_sm75() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_s8_sm75() {
     ElementC,
     64 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   2,
@@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_s8_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_s8_sm75() {
+bool run_fused_conv2d_fprop_s8_sm75() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_s8_sm75() {
     ElementC,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -207,9 +208,10 @@ void run_fused_conv2d_fprop_s8_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_nonfused_conv2d_fprop_optimized_s8_sm75() {
+bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -242,7 +244,8 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm75() {
     ElementC,
     64 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   2,
@@ -287,9 +290,10 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_optimized_s8_sm75() {
+bool run_fused_conv2d_fprop_optimized_s8_sm75() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -313,7 +317,8 @@ void run_fused_conv2d_fprop_optimized_s8_sm75() {
     ElementC,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -359,9 +364,8 @@ void run_fused_conv2d_fprop_optimized_s8_sm75() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif // if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
@@ -36,8 +36,6 @@
 #include "device/b2b_implicit_gemm_convolution.h"
 #include "b2b_interleaved_conv2d_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
@@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
   {128, 56, 56, 64}    // output size (NPQK)
 );
 
-void run_nonfused_conv2d_fprop_s8_sm80() {
+bool run_nonfused_conv2d_fprop_s8_sm80() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_s8_sm80() {
     ElementC,
     64 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   3,
@@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_s8_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_s8_sm80() {
+bool run_fused_conv2d_fprop_s8_sm80() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_s8_sm80() {
     ElementC,
     8 * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -207,9 +208,10 @@ void run_fused_conv2d_fprop_s8_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_nonfused_conv2d_fprop_optimized_s8_sm80() {
+bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -242,7 +244,8 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm80() {
     ElementC,
     64 / cutlass::sizeof_bits<ElementC>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   3,
@@ -287,9 +290,10 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_conv2d_fprop_optimized_s8_sm80() {
+bool run_fused_conv2d_fprop_optimized_s8_sm80() {
 
   using ElementA = int8_t;
   using ElementB = int8_t;
@@ -313,7 +317,8 @@ void run_fused_conv2d_fprop_optimized_s8_sm80() {
     ElementC,
     8 * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -359,10 +364,9 @@ void run_fused_conv2d_fprop_optimized_s8_sm80() {
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
@@ -39,14 +39,12 @@
 #include "device/b2b_gemm.h"
 #include "b2b_gemm_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*1600, 64, 576);
 cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*1600, 128, 64);
 
-void run_nonfused_gemm_f16() {
+bool run_nonfused_gemm_f16() {
 
   using ElementOutput = cutlass::half_t;
   using ElementAccumulator = cutlass::half_t;
@@ -80,7 +78,8 @@ void run_nonfused_gemm_f16() {
     ElementOutput,
     128 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   2
@@ -116,9 +115,11 @@ void run_nonfused_gemm_f16() {
     std::cout << "Pass\n";
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_gemm_f16() {
+bool run_fused_gemm_f16() {
 
   using ElementOutput = cutlass::half_t;
   using ElementAccumulator = cutlass::half_t;
@@ -140,7 +141,8 @@ void run_fused_gemm_f16() {
     ElementOutput,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -183,7 +185,6 @@ void run_fused_gemm_f16() {
   else
     std::cout << "Fail\n";
 
+  return passed;
 }
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
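The two problem sizes above are chained: the first GEMM's N must equal the second GEMM's K so the first GEMM's output tile can feed the second GEMM directly from registers. A sketch of that invariant, quoting the sizes from this file:

```cpp
#include <cassert>
#include "cutlass/gemm_coord.h"

int main() {
  cutlass::gemm::GemmCoord problem_size_0(128 * 1600, 64, 576);
  cutlass::gemm::GemmCoord problem_size_1(128 * 1600, 128, 64);

  // Fusion requirement: same M, and K of GEMM1 equals N of GEMM0.
  assert(problem_size_0.m() == problem_size_1.m());
  assert(problem_size_1.k() == problem_size_0.n());
  return 0;
}
```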
@@ -39,14 +39,12 @@
 #include "device/b2b_gemm.h"
 #include "b2b_gemm_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*1600, 64, 576);
 cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*1600, 128, 64);
 
-void run_nonfused_gemm_f16_sm80() {
+bool run_nonfused_gemm_f16_sm80() {
 
   using ElementOutput = cutlass::half_t;
   using ElementAccumulator = cutlass::half_t;
@@ -80,7 +78,8 @@ void run_nonfused_gemm_f16_sm80() {
     ElementOutput,
     128 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   3
@@ -116,9 +115,11 @@ void run_nonfused_gemm_f16_sm80() {
     std::cout << "Pass\n";
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_gemm_f16_sm80() {
+bool run_fused_gemm_f16_sm80() {
 
   using ElementOutput = cutlass::half_t;
   using ElementAccumulator = cutlass::half_t;
@@ -140,7 +141,8 @@ void run_fused_gemm_f16_sm80() {
     ElementOutput,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -183,7 +185,7 @@ void run_fused_gemm_f16_sm80() {
   else
     std::cout << "Fail\n";
 
+  return passed;
 
 }
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif //#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
@@ -39,14 +39,12 @@
 #include "device/b2b_gemm.h"
 #include "b2b_interleaved_gemm_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*1600, 64, 576);
 cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*1600, 128, 64);
 
-void run_nonfused_gemm_s8() {
+bool run_nonfused_gemm_s8() {
 
   using ElementOutput = int8_t;
   using ElementAccumulator = int32_t;
@@ -80,7 +78,8 @@ void run_nonfused_gemm_s8() {
     ElementOutput,
     64 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
   2
@@ -116,9 +115,11 @@ void run_nonfused_gemm_s8() {
     std::cout << "Pass\n";
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_gemm_s8() {
+bool run_fused_gemm_s8() {
 
   using ElementOutput = int8_t;
   using ElementAccumulator = int32_t;
@@ -140,7 +141,8 @@ void run_fused_gemm_s8() {
     ElementOutput,
     InstructionShape::kM * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -151,8 +153,6 @@ void run_fused_gemm_s8() {
     ElementCompute
   >;
 
   using B2bGemm = cutlass::gemm::device::B2bGemm<
     int8_t,
     cutlass::layout::ColumnMajorInterleaved<32>,
@@ -183,7 +183,7 @@ void run_fused_gemm_s8() {
   else
     std::cout << "Fail\n";
 
+  return passed;
 
 }
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
@@ -39,14 +39,12 @@
 #include "device/b2b_gemm.h"
 #include "b2b_interleaved_gemm_run.h"
 
-#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*1600, 64, 576);
 cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*1600, 128, 64);
 
-void run_nonfused_gemm_s8_sm80() {
+bool run_nonfused_gemm_s8_sm80() {
 
   using ElementOutput = int8_t;
   using ElementAccumulator = int32_t;
@@ -80,7 +78,8 @@ void run_nonfused_gemm_s8_sm80() {
     ElementOutput,
     64 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
   3,
@@ -106,7 +105,8 @@ void run_nonfused_gemm_s8_sm80() {
     ElementOutput,
     64 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >,
   cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
   3,
@@ -124,9 +124,11 @@ void run_nonfused_gemm_s8_sm80() {
     std::cout << "Pass\n";
   else
     std::cout << "Fail\n";
 
+  return pass;
 }
 
-void run_fused_gemm_s8_sm80() {
+bool run_fused_gemm_s8_sm80() {
 
   using ElementOutput = int8_t;
   using ElementAccumulator = int32_t;
@@ -148,7 +150,8 @@ void run_fused_gemm_s8_sm80() {
     ElementOutput,
     8 * InstructionShape::kN / 32,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using EpilogueOutputOp1 =
@@ -156,11 +159,10 @@ void run_fused_gemm_s8_sm80() {
     ElementOutput,
     64 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
-    ElementCompute
+    ElementCompute,
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
   >;
 
   using B2bGemm = cutlass::gemm::device::B2bGemm<
     int8_t,
     cutlass::layout::ColumnMajorInterleaved<32>,
@@ -183,8 +185,7 @@ void run_fused_gemm_s8_sm80() {
     16,
     16,
     false,
-    cutlass::arch::OpMultiplyAddSaturate,
-    true
+    cutlass::arch::OpMultiplyAddSaturate
   >;
 
   B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
@@ -196,7 +197,6 @@ void run_fused_gemm_s8_sm80() {
   else
     std::cout << "Fail\n";
 
+  return passed;
 }
 ////////////////////////////////////////////////////////////////////////////////
 
-#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
@@ -115,9 +115,7 @@ template <
     /// Operation performed by GEMM
     typename Operator_ = typename DefaultGemmConfiguration<
         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
-        ElementAccumulator_>::Operator,
-    /// Whether Beta is zero or not
-    bool IsBetaZero = false>
+        ElementAccumulator_>::Operator>
 class B2bGemm {
  public:
@@ -148,7 +146,6 @@ class B2bGemm {
   static int const kAlignmentB = AlignmentB;
   static int const kAlignmentC = EpilogueOutputOp1::kCount;
   static bool const kSplitKSerial = SplitKSerial;
-  static bool const kIsBetaZero = IsBetaZero;
   static ComplexTransform const kTransformA = ComplexTransform::kNone;
   static ComplexTransform const kTransformB = ComplexTransform::kNone;
 
@@ -175,8 +172,7 @@ class B2bGemm {
     ThreadblockSwizzle,
     kStages,
     kSplitKSerial,
-    Operator,
-    kIsBetaZero
+    Operator
   >::B2bGemmKernel;
 
   /// Argument structure
@@ -422,7 +418,7 @@ public:
     void *workspace = nullptr,
     cudaStream_t stream = nullptr) {
 
-    Status status = initialize(args, workspace);
+    Status status = initialize(args, workspace, stream);
 
     if (status == Status::kSuccess) {
       status = run(stream);
@@ -255,7 +255,7 @@ public:
     void *workspace = nullptr,
     cudaStream_t stream = nullptr) {
 
-    Status status = initialize(args, workspace);
+    Status status = initialize(args, workspace, stream);
 
     if (status == Status::kSuccess) {
       status = run(stream);
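Seen from the caller, the effect of the `initialize(args, workspace, stream)` fix is that the stream passed to the device-level operator now reaches initialization as well as the kernel launch. A sketch, where the operator, arguments, and workspace are placeholders set up as in the examples above:

```cpp
cudaStream_t stream = nullptr;
cudaStreamCreate(&stream);

B2bGemm b2b_gemm_op;
// Both initialize() and run() are now issued with the same stream,
// so setup work no longer lands implicitly on the default stream.
cutlass::Status status = b2b_gemm_op(args, workspace, stream);

cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);
```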
@@ -28,53 +28,14 @@
 #include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h"
 #include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h"
 
-int run() {
-
-  cudaDeviceProp props;
-
-  cudaError_t error = cudaGetDeviceProperties(&props, 0);
-  if (error != cudaSuccess) {
-    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
-    return -1;
-  }
-
-  if (!(props.major * 10 + props.minor >= 75)) {
-    std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
-              << std::endl;
-
-    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
-    return 0;
-  }
-
-#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-  std::cout << "Running on SM80" << std::endl;
-  run_nonfused_conv2d_fprop_optimized_f16_sm80();
-  run_fused_conv2d_fprop_optimized_f16_sm80();
-  run_nonfused_conv2d_fprop_optimized_s8_sm80();
-  run_fused_conv2d_fprop_optimized_s8_sm80();
-#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-  std::cout << "Running on SM75" << std::endl;
-  run_nonfused_conv2d_fprop_optimized_f16_sm75();
-  run_fused_conv2d_fprop_optimized_f16_sm75();
-  run_nonfused_conv2d_fprop_optimized_s8_sm75();
-  run_fused_conv2d_fprop_optimized_s8_sm75();
-#endif
-
-  return 0;
-}
-
-int main() {
-
+int run_sm75() {
   bool notSupported = false;
 
   // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
   //
   // CUTLASS must be compiled with the CUDA 10.2 Toolkit to run these examples.
   if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
     std::cerr << "Tensor Core operations used in this example must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
     notSupported = true;
   }
 
   cudaDeviceProp props;
@@ -85,10 +46,7 @@ int main() {
     return -1;
   }
 
-  if (!(props.major * 10 + props.minor >= 75)) {
-    std::cerr << "Tensor Ops used in this example must be run on a machine with compute capability at least 75."
-              << std::endl;
-
+  if (!(props.major == 7 && props.minor >= 5)) {
     notSupported = true;
   }
 
@@ -96,7 +54,83 @@ int main() {
     // Returning zero so this test passes on older Toolkits. Its actions are no-op.
     return 0;
   }
 
-  return run();
+  bool pass = true;
+
+  std::cout << "Running on SM75" << std::endl;
+  pass &= run_nonfused_conv2d_fprop_optimized_f16_sm75();
+  pass &= run_fused_conv2d_fprop_optimized_f16_sm75();
+  pass &= run_nonfused_conv2d_fprop_optimized_s8_sm75();
+  pass &= run_fused_conv2d_fprop_optimized_s8_sm75();
+
+  if (pass)
+    return 1;
+  else
+    return -1;
+}
+
+int run_sm80() {
+  bool notSupported = false;
+
+  // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
+  //
+  // CUTLASS must be compiled with the CUDA 11 Toolkit to run Conv2dFprop examples.
+  if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
+    notSupported = true;
+  }
+
+  cudaDeviceProp props;
+
+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
+  if (error != cudaSuccess) {
+    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+    return -1;
+  }
+
+  if (!(props.major == 8 && props.minor >= 0)) {
+    notSupported = true;
+  }
+
+  if (notSupported) {
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  bool pass = true;
+
+  std::cout << "Running on SM80" << std::endl;
+  pass &= run_nonfused_conv2d_fprop_optimized_f16_sm80();
+  pass &= run_fused_conv2d_fprop_optimized_f16_sm80();
+  pass &= run_nonfused_conv2d_fprop_optimized_s8_sm80();
+  pass &= run_fused_conv2d_fprop_optimized_s8_sm80();
+
+  if (pass)
+    return 1;
+  else
+    return -1;
+}
+
+int main() {
+
+  int result = 0;
+
+  result = run_sm80();
+
+  if (!result) { // not supported; fall back to SM75
+    result = run_sm75();
+
+    if (!result) {
+      std::cout << "This example isn't supported on the current architecture" << std::endl;
+    }
+  }
+
+  if (result >= 0)
+    return 0;
+  else
+    return -1;
+}
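The restructured drivers (both here and in the GEMM version that follows) use a three-way return convention: `1` means the runners executed and every check passed, `0` means the architecture or toolkit is unsupported and is treated as a skip, which is what lets `main()` fall back from SM80 to SM75, and `-1` means at least one check failed. `main()` therefore exits non-zero only in the genuine failure case.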
@@ -28,36 +28,15 @@
 #include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
 #include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h"
 
-int run() {
-
-#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
-  std::cout << "Running on SM80" << std::endl;
-  run_nonfused_gemm_f16_sm80();
-  run_fused_gemm_f16_sm80();
-  run_nonfused_gemm_s8_sm80();
-  run_fused_gemm_s8_sm80();
-#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
-  std::cout << "Running on SM75" << std::endl;
-  run_nonfused_gemm_f16();
-  run_fused_gemm_f16();
-  run_nonfused_gemm_s8();
-  run_fused_gemm_s8();
-#endif
-
-  return 0;
-}
-
-int main() {
-
+int run_sm75() {
   bool notSupported = false;
 
   // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
   //
   // CUTLASS must be compiled with the CUDA 10.2 Toolkit to run these examples.
   if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
     std::cerr << "Tensor Core operations used in this example must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
     notSupported = true;
   }
 
   cudaDeviceProp props;
@@ -68,10 +47,7 @@ int main() {
     return -1;
   }
 
-  if (!(props.major * 10 + props.minor >= 75)) {
-    std::cerr << "Tensor Ops used in this example must be run on a machine with compute capability at least 75."
-              << std::endl;
-
+  if (!(props.major == 7 && props.minor >= 5)) {
     notSupported = true;
   }
 
@@ -80,6 +56,86 @@ int main() {
     return 0;
   }
 
-  return run();
+  bool pass = true;
+
+  std::cout << "Running on SM75" << std::endl;
+  pass &= run_nonfused_gemm_f16();
+  pass &= run_fused_gemm_f16();
+  pass &= run_nonfused_gemm_s8();
+  pass &= run_fused_gemm_s8();
+
+  if (pass)
+    return 1;
+  else
+    return -1;
+}
+
+int run_sm80() {
+  bool notSupported = false;
+
+  // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
+  //
+  // CUTLASS must be compiled with the CUDA 11 Toolkit to run these examples.
+  if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
+    notSupported = true;
+  }
+
+  cudaDeviceProp props;
+
+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
+  if (error != cudaSuccess) {
+    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+    return -1;
+  }
+
+  if (!(props.major == 8 && props.minor >= 0)) {
+    notSupported = true;
+  }
+
+  if (notSupported) {
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  bool pass = true;
+
+  std::cout << "Running on SM80" << std::endl;
+  pass &= run_nonfused_gemm_f16_sm80();
+  pass &= run_fused_gemm_f16_sm80();
+  pass &= run_nonfused_gemm_s8_sm80();
+  pass &= run_fused_gemm_s8_sm80();
+
+  if (pass)
+    return 1;
+  else
+    return -1;
+}
+
+int main() {
+
+  int result = 0;
+
+  result = run_sm80();
+
+  if (!result) { // not supported; fall back to SM75
+    result = run_sm75();
+
+    if (!result) {
+      std::cout << "This example isn't supported on the current architecture" << std::endl;
+    }
+  }
+
+  if (result >= 0)
+    return 0;
+  else
+    return -1;
+}
@@ -66,6 +66,7 @@ struct B2bGemm {
     cutlass::gemm::GemmCoord problem_size_0;
     cutlass::gemm::GemmCoord problem_size_1;
     cutlass::gemm::GemmCoord grid_tiled_shape;
+    int swizzle_log_tile;
     typename B2bMma::IteratorA0::Params params_A0;
     typename B2bMma::IteratorA0::TensorRef ref_A0;
     typename B2bMma::IteratorB0::Params params_B0;
@@ -91,7 +92,7 @@ struct B2bGemm {
     //
 
     CUTLASS_HOST_DEVICE
-    Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
+    Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
       gemm_k_iterations_1(0), gemm_k_size_1(0) { }
 
     CUTLASS_HOST_DEVICE
@@ -112,6 +113,7 @@ struct B2bGemm {
       problem_size_0(problem_size_0),
       problem_size_1(problem_size_1),
      grid_tiled_shape(grid_tiled_shape),
+      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
       params_A0(ref_A0.layout()),
       ref_A0(ref_A0),
       params_B0(ref_B0.layout()),
@@ -211,7 +213,7 @@ struct B2bGemm {
     ThreadblockSwizzle threadblock_swizzle;
 
     cutlass::gemm::GemmCoord threadblock_tile_offset =
-        threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
 
     // Early exit if CTA is out of range
     if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@@ -315,7 +317,7 @@ struct B2bGemm {
     //
 
     threadblock_tile_offset =
-        threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
 
     // assume identity swizzle
     MatrixCoord threadblock_offset(
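Condensed, the new pattern precomputes the swizzle's log-tile once on the host and stores it in `Params`, so every CTA reads a ready value instead of re-deriving it from `grid_tiled_shape`. A sketch using the identity swizzle these examples instantiate:

```cpp
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

// Host side: computed once and stored in Params as swizzle_log_tile.
int precompute_log_tile(cutlass::gemm::GemmCoord grid_tiled_shape) {
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> swizzle;
  return swizzle.get_log_tile(grid_tiled_shape);
}

// Device side then calls threadblock_swizzle.get_tile_offset(params.swizzle_log_tile)
// rather than get_tile_offset(params.grid_tiled_shape), as the hunks above show.
```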
@@ -209,6 +209,7 @@ struct B2bImplicitGemmConvolution {
     cutlass::gemm::GemmCoord grid_tiled_shape;
     gemm::GemmCoord implicit_gemm_problem_size_0;
     gemm::GemmCoord implicit_gemm_problem_size_1;
+    int swizzle_log_tile;
     int gemm_k_iterations_0;
     int gemm_k_iterations_1;
     typename B2bMma::IteratorA0::Params iterator_A0;
@@ -233,7 +234,7 @@ struct B2bImplicitGemmConvolution {
     //
 
     CUTLASS_HOST_DEVICE
-    Params(): gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
+    Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
 
     ///
     CUTLASS_HOST_DEVICE
@@ -245,7 +246,6 @@ struct B2bImplicitGemmConvolution {
       problem_size_1(args.problem_size_1),
       implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)),
       implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)),
-      grid_tiled_shape(grid_tiled_shape),
       iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())),
       ptr_A0(args.ref_A0.data()),
       iterator_B0(args.problem_size_0, args.ref_B0.layout()),
@@ -272,6 +272,8 @@ struct B2bImplicitGemmConvolution {
         implicit_gemm_problem_size_0,
         {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
         args.problem_size_0.split_k_slices);
+
+      swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
     }
   };
 
@@ -296,7 +298,7 @@ struct B2bImplicitGemmConvolution {
     ThreadblockSwizzle threadblock_swizzle;
 
     cutlass::gemm::GemmCoord threadblock_tile_idx =
-        threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
 
     // Early exit if CTA is out of range
     if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
@@ -379,7 +381,7 @@ struct B2bImplicitGemmConvolution {
 
     // Compute logical position within grid
     threadblock_tile_idx =
-        threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
 
     // If performing a reduction via split-K, fetch the initial synchronization
     if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
@@ -111,9 +111,7 @@ template <
     /// If true, kernel is configured to support serial reduction in the epilogue
     bool SplitKSerial,
     /// Operation performed by GEMM
-    typename Operator,
-    /// Beta is zero or not
-    bool IsBetaZero = false
+    typename Operator
 >
 struct DefaultB2bGemm;
 
@@ -321,9 +319,7 @@ template <
     /// epilogue
     bool SplitKSerial,
     /// Operation performed by GEMM
-    typename Operator,
-    /// Is Beta zero or not
-    bool IsBetaZero>
+    typename Operator>
 struct DefaultB2bGemm<
     ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
     ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
@@ -332,7 +328,7 @@ struct DefaultB2bGemm<
     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
     InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
     ThreadblockSwizzle, Stages,
-    SplitKSerial, Operator, IsBetaZero> {
+    SplitKSerial, Operator> {
   using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
   using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
   using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -353,8 +349,7 @@ struct DefaultB2bGemm<
   using Epilogue = typename cutlass::epilogue::threadblock::
       DefaultInterleavedEpilogueTensorOp<
           ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
-          64 / sizeof_bits<ElementC>::value, InterleavedK,
-          IsBetaZero>::Epilogue;
+          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
   using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
@@ -397,9 +392,7 @@ template <
     /// epilogue
     bool SplitKSerial,
     /// Operation performed by GEMM
-    typename Operator,
-    /// Is Beta zero or not
-    bool IsBetaZero>
+    typename Operator>
 struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
                       kAlignmentA, ElementB,
                       layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
@@ -407,7 +400,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
     int32_t, arch::OpClassTensorOp, arch::Sm75,
     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
     InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
-    ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
+    ThreadblockSwizzle, 2, SplitKSerial, Operator> {
   using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
   using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
   using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -426,8 +419,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
   using Epilogue = typename cutlass::epilogue::threadblock::
       DefaultInterleavedEpilogueTensorOp<
           ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
-          64 / sizeof_bits<ElementC>::value, InterleavedK,
-          IsBetaZero>::Epilogue;
+          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
   using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
@ -43,14 +43,122 @@ fp32 data by using NVIDIA Ampere architecture.
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Result(
|
||||
double runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess
|
||||
):
|
||||
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int batch_count;
|
||||
float alpha;
|
||||
float beta;
|
||||
|
||||
bool reference_check;
|
||||
int iterations;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({5120, 4096, 4096}),
|
||||
batch_count(1),
|
||||
reference_check(true),
|
||||
iterations(20),
|
||||
alpha(1),
|
||||
beta() { }
|
||||
|
||||
bool valid() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "14_ampere_tf32_tensorop_gemm example\n\n"
|
||||
<< " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha\n"
|
||||
<< " --beta <f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n"
|
||||
<< " --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = problem_size.product() * batch_count;
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
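// Worked example (added for illustration, not part of the original commit): for the
// default problem size 5120 x 4096 x 4096 with batch_count = 1, fmas is
// 5120 * 4096 * 4096 ~= 85.9e9, so gflops() returns 2 * 85.9e9 / 1e9 / runtime_s
// ~= 171.8 / runtime_s GFLOP/s; a 10 ms run would report roughly 17,180 GFLOP/s.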
///////////////////////////////////////////////////////////////////////////////////////////////////

// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float;  // <- data type of accumulator
@ -111,14 +219,10 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
SwizzleThreadBlock,
NumStages>;

int run() {

const int length_m = 5120;
const int length_n = 4096;
const int length_k = 4096;
int run(Options &options) {

// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
cutlass::gemm::GemmCoord problem_size = options.problem_size;

// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
@ -166,8 +270,8 @@ int run() {
tensor_ref_d.sync_device();

// Initialize alpha and beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);

// Split K dimension into 1 partition
int split_k_slices = 1;
@ -199,9 +303,74 @@ int run() {
status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);

// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
// Result structure
Result result;

//
// Construct events
//

cudaEvent_t events[2];

for (auto & event : events) {
  result.error = cudaEventCreate(&event);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }
}

// Record an event at the start of a series of GEMMs
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
  std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
  return -1;
}

//
// Run profiling loop
//

for (int iter = 0; iter < options.iterations; ++iter) {
  // Launch initialized CUTLASS kernel
  status = gemm_op();
  CUTLASS_CHECK(status);
}

//
// Stop profiling loop
//

// Record an event when the GEMMs are complete
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
  std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
  return -1;
}

// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
  std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
  return -1;
}

// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
  std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
  return -1;
}

// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);

// Cleanup
for (auto event : events) {
  (void)cudaEventDestroy(event);
}

// Create instantiation for device reference gemm kernel
cutlass::reference::device::Gemm<ElementInputA,
@ -235,12 +404,17 @@ int run() {
tensor_d.host_view(),
tensor_ref_d.host_view());

if (passed) {
  std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
  std::cout << " GFLOPs: " << result.gflops << std::endl;
}

std::cout << (passed ? "Passed" : "Failed") << std::endl;

return (passed ? 0 : -1);
}

int main() {
int main(int argc, const char **argv) {

bool notSupported = false;

@ -272,5 +446,21 @@ int main() {
return 0;
}

return run();
Options options;
options.parse(argc, argv);

if (options.help) {
  options.print_usage(std::cout) << std::endl;
  return 0;
}

printf("%d x %d x %d TF32 tensor op Matrix Multiply\n", \
  options.problem_size.m(), options.problem_size.n(), options.problem_size.k());

if (!options.valid()) {
  std::cerr << "Invalid problem." << std::endl;
  return -1;
}

return run(options);
}

@ -152,7 +152,7 @@ int run() {
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
// Same size as the above. The above one needs to be reordered and stored in this one.
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));

// Fill input and output matrices on host using CUTLASS helper functions

@ -158,7 +158,7 @@ using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSw
constexpr int NumStages = 3;

// This code section describes whether the selected iterator algorithm is Analytic or Optimized
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kAnalytic;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;

// This code section describes the epilogue part of the kernel; we use the default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@ -189,7 +189,6 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<

using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
@ -755,6 +754,3 @@ int main(int argc, char const **args) {
}

/////////////////////////////////////////////////////////////////////////////////////////////////

28
examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,28 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  18_ampere_fp64_tensorop_affine2_gemm
  ampere_fp64_tensorop_affine2_gemm.cu
)

@ -0,0 +1,336 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/**
  In a normal GEMM, the fast-changing dimension of a matrix always has a stride
  equal to 1, e.g. ColumnMajor and RowMajor matrices. An Affine2 matrix can have
  a stride larger than 1 in both dimensions. To support such layouts, we need to
  change the method used to visit global memory:

    1. We can only visit one element at a time, because elements are no longer
       stored consecutively. Vectorized load/store is not possible.
    2. One extra multiplication is needed to calculate the global memory
       address:
         addr = base_pointer + coord1 * stride1 + coord2 * stride2

  The rest of the GEMM, including shared memory load/store and MMA computation,
  is the same.

  This example uses an Ampere fp64 tensor core Affine2 GEMM. SIMT kernels
  (e.g. sgemm, dgemm) also support the Affine2 layout.
*/
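// Illustrative sketch (added for this write-up; affine2_address is a hypothetical
// helper, not part of the original commit): the affine rank-2 address computation
// described above, written out for a single element.
//
//   template <typename T>
//   T const *affine2_address(T const *base, int64_t row, int64_t col,
//                            int64_t stride_row, int64_t stride_col) {
//     // Both strides are explicit, so neither dimension is assumed contiguous,
//     // which is why accesses cannot be vectorized.
//     return base + row * stride_row + col * stride_col;
//   }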
#include <iostream>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/matrix_coord.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"

#include "helper.h"

// The code section below describes the datatypes for input and output tensors and the
// computation between elements
using ElementAccumulator = double;                  // Data type of accumulator
using ElementComputeEpilogue = ElementAccumulator;  // Data type of epilogue computation
using ElementInputA = double;                       // Data type of elements in input tensor
using ElementInputB = double;                       // Data type of elements in input tensor
using ElementOutput = double;                       // Data type of elements in output tensor

// Since Affine2 explicitly lists the strides of both dimensions, it does not really matter
// whether it is column-major or row-major. However, it helps CUTLASS improve load locality
// if it knows which dimension of the A/B operand has the smaller stride, i.e. is denser.
//
// Affine2 ColumnMajor means the row stride is smaller, and Affine2 RowMajor means the column
// stride is smaller.
//
// The Affine2 epilogue reuses the AffineN epilogue, so it does not need to specify
// column-major or row-major.
using LayoutInputA = cutlass::layout::AffineRank2ColumnMajor;
using LayoutInputB = cutlass::layout::AffineRank2RowMajor;
using LayoutOutput = cutlass::layout::AffineRankN<2>;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>;  // Threadblock tile shape

// This code section describes the tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>;           // Warp tile shape

// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;       // TensorCore instruction shape

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

// Number of pipelines you want to use
constexpr int NumStages = 3;

// This code section describes the epilogue part of the kernel; we use the default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
  ElementOutput,  // Data type of output matrix.
  1,              // The number of elements per memory access. It has to be 1 for Affine2.
  ElementComputeEpilogue>;

using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal<
  ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 1,  // AlignmentA has to be 1
  ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 1,  // AlignmentB has to be 1
  ElementOutput, LayoutOutput,
  ElementAccumulator,
  MMAOp,
  SmArch,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  cutlass::arch::OpMultiplyAdd
>::GemmKernel;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

/////////////////////////////////////////////////////////////////////////////////////////////////

int run() {

  // Construct Gemm ProblemSize with user defined output size
  cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024};

  // The stride factor gives the distance between two elements in the different dimensions:
  // the first value is the logical distance between two rows, the second between two columns.
  // CUTLASS provides the utility cutlass::layout::Affine2Layout_Factory<Layout>::layout_factory
  // to convert a stride_factor into the two strides.
  //
  // It is also totally fine to compute the strides directly and construct the affine2 layout
  // without using the utility.
  typename LayoutInputA::Stride::Index stride_factor_A[] = {3, 4};
  typename LayoutInputB::Stride::Index stride_factor_B[] = {5, 6};
  typename LayoutOutput::Stride::Index stride_factor_C[] = {7, 8};

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk(),
      cutlass::layout::Affine2Layout_Factory<LayoutInputA>::layout_factory(problem_size.mk(),
          stride_factor_A));
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn(),
      cutlass::layout::Affine2Layout_Factory<LayoutInputB>::layout_factory(problem_size.kn(),
          stride_factor_B));

  // Create matrix C used to load for bias addition.
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn(),
      cutlass::layout::Affine2Layout_Factory<LayoutOutput>::layout_factory(problem_size.mn(),
          stride_factor_C));

  // Create matrix D used to store output from CUTLASS kernel
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn(),
      cutlass::layout::Affine2Layout_Factory<LayoutOutput>::layout_factory(problem_size.mn(),
          stride_factor_C));

  // Create matrix D with dimensions M x N used to store output from reference kernel
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(problem_size.mn(),
      cutlass::layout::Affine2Layout_Factory<LayoutOutput>::layout_factory(problem_size.mn(),
          stride_factor_C));

  // Fill input and output matrices on host using CUTLASS helper functions
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(4),
      ElementInputA(-4),
      0);  // <- Fill matrix A on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(4),
      ElementInputB(-4),
      0);  // <- Fill matrix B on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      ElementOutput(4),
      ElementOutput(-4),
      0);  // <- Fill matrix C on host with uniform-distribution random data

  cutlass::reference::host::TensorFill(
      tensor_d.host_view());      // <- fill matrix D on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  // Initialize alpha and beta for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(1);

  cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;

  int batch_count = 1;

  // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
  // the instantiated CUTLASS kernel
  typename Gemm::Arguments arguments{
      mode,
      problem_size,
      batch_count,
      {alpha, beta},
      tensor_a.device_ref().data(),  // <- reference to matrix A on device
      tensor_b.device_ref().data(),  // <- reference to matrix B on device
      tensor_c.device_ref().data(),  // <- reference to matrix C on device
      tensor_d.device_ref().data(),  // <- reference to matrix D on device
      tensor_a.layout().capacity(problem_size.mk()),
      tensor_b.layout().capacity(problem_size.kn()),
      tensor_c.layout().capacity(problem_size.mn()),
      tensor_d.layout().capacity(problem_size.mn()),
      tensor_a.layout().stride(),
      tensor_b.layout().stride(),
      tensor_c.layout().stride(),
      tensor_d.layout().stride()
  };
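  // Note on the argument order above (explanatory comment added for illustration):
  // after the epilogue scalars {alpha, beta} come the four device pointers, then the
  // four batch strides (computed here from layout().capacity()), and finally the four
  // layout strides that carry the Affine2 row/column strides into the kernel.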
  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // Launch initialized CUTLASS kernel
  status = gemm_op();

  CUTLASS_CHECK(status);

  //
  // Create instantiation for device reference gemm kernel
  //

  // Launch device reference to compute strictly the product A * B
  cutlass::reference::device::Gemm<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementComputeEpilogue,
      ElementAccumulator> gemm_device;

  gemm_device(
      problem_size,
      alpha,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      beta,
      tensor_c.device_ref(),
      tensor_ref_d.device_ref()
  );

  // Wait for kernels to finish
  cudaDeviceSynchronize();

  // Copy output data from CUTLASS and reference kernel to host for comparison
  tensor_d.sync_host();
  tensor_ref_d.sync_host();

  bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(),
                                                     tensor_ref_d.host_view());

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  std::cout << (pass ? "Passed" : "Failed") << std::endl;

  CUTLASS_CHECK(status);

  return 0;
}

int main(int argc, char const **args) {

  bool notSupported = false;

  // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
  //
  // CUTLASS must be compiled with the CUDA 11 Toolkit to run this example.
  if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
    std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
    notSupported = true;
  }

  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;
  }

  if (notSupported) {
    return 0;
  }

  return run();
}

/////////////////////////////////////////////////////////////////////////////////////////////////

27
examples/19_tensorop_canonical/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  19_tensorop_canonical
  tensorop_canonical.cu
)

432
examples/19_tensorop_canonical/tensorop_canonical.cu
Normal file
@ -0,0 +1,432 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*
  This example requires an NVIDIA Ampere GPU or later.
*/

// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>

// CUTLASS Includes
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"

// CUTLASS Utility Includes
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gemm_complex.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

// Define the overall warp-level problem shape
int const kM = 27;
int const kN = 31;
int const kK = 17;

///////////////////////////////////////////////////////////////////////////////////////////////////

// Define a warp-level GEMM operator.
//
// This template could be part of the CUTLASS Template Library or implemented internally. This
// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be
// instantiated in device code.

namespace cutlass {
namespace gemm {
namespace warp {

template <
  typename Shape,
  typename InstructionShape,
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementScalar
>
class GemmTensorOp {
public:

  using WarpShape = GemmShape<
    ((Shape::kM + InstructionShape::kM - 1) / InstructionShape::kM) * InstructionShape::kM,
    ((Shape::kN + InstructionShape::kN - 1) / InstructionShape::kN) * InstructionShape::kN,
    InstructionShape::kK
  >;
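  // Worked example (added for illustration): for the overall 27 x 31 x 17 problem and
  // the 8x8x4 instruction used below, WarpShape rounds up to 32 x 32 x 4, since
  // ((27 + 7) / 8) * 8 == 32 and ((31 + 7) / 8) * 8 == 32.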
  using MmaWarp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
    WarpShape,
    InstructionShape,
    double,                        // Data type of A elements
    cutlass::layout::RowMajor,     // Layout of A matrix
    double,                        // Data type of B elements
    cutlass::layout::ColumnMajor,  // Layout of B matrix
    double,                        // Data type of C elements
    cutlass::layout::RowMajor      // Layout of C matrix
  >::Type;

  // Number of 'K groups'
  int const kKgroups = (Shape::kK + InstructionShape::kK - 1) / InstructionShape::kK;
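  // Worked example (added for illustration): with Shape::kK = 17 and
  // InstructionShape::kK = 4, kKgroups = (17 + 3) / 4 = 5, so the mainloop below
  // issues five K-group iterations.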
  // Define a 'FragmentIterator' to iterate over slices of accumulators
  using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
    typename MmaWarp::Shape,
    InstructionShape,
    double,
    typename MmaWarp::Policy::Operator::FragmentC,
    cutlass::layout::RowMajor
  >;

  // Define an epilogue 'Tile Iterator' to iterate over slices of elements in Shared Memory
  using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpCanonical<
    typename MmaWarp::Shape,
    InstructionShape,
    double,
    cutlass::layout::RowMajor
  >;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
  using TensorRefC = typename AccumulatorTileIterator::TensorRef;

public:
  CUTLASS_HOST_DEVICE
  GemmTensorOp() { }

  CUTLASS_DEVICE
  void operator()(
    ElementScalar alpha,
    TensorRefA ref_A,
    TensorRefB ref_B,
    ElementScalar beta,
    TensorRefC ref_C,
    TensorRefC ref_D,
    int lane_id) const {

    // Instantiate iterators pointing to slices of the A and B matrices in shared memory
    typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id);
    typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id);

    // Instantiate and clear accumulator tile holding the C matrix
    typename MmaWarp::FragmentC accum;
    accum.clear();

    // Instantiate the warp-level matrix multiply operator
    MmaWarp mma_op;

    // Instantiate fragments holding the slice of the matrix held by each warp
    typename MmaWarp::FragmentA frag_A[2];
    typename MmaWarp::FragmentB frag_B[2];
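    // (Explanatory comment added for illustration.) Two fragments per operand
    // implement software double buffering: while the MMA consumes fragment
    // k % 2, the next K group is prefetched into fragment (k + 1) % 2.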
    // Load fragments from shared memory
    iter_A.load(frag_A[0]);
    iter_B.load(frag_B[0]);

    ++iter_A;
    ++iter_B;

    // Mainloop: iterate over the K groups
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {

      // Load fragments from shared memory
      iter_A.load(frag_A[(k + 1) % 2]);
      iter_B.load(frag_B[(k + 1) % 2]);

      ++iter_A;
      ++iter_B;

      // Compute the matrix multiply
      mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
    }

    // Instantiate iterators
    FragmentIterator accum_frag_it(accum);
    AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id);
    AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id);

    // Define function objects for linear scaling operation
    cutlass::multiplies<typename FragmentIterator::Fragment> mul_source;
    cutlass::multiply_add<typename FragmentIterator::Fragment> mul_add_accumulator;

    // Iterate over the epilogue components
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) {

      // Define storage for slices of the accumulators
      typename FragmentIterator::Fragment accum_fragment;
      typename FragmentIterator::Fragment source_fragment;

      // Select a slice of accumulators from the accumulator tile
      accum_frag_it.load(accum_fragment);
      ++accum_frag_it;

      // Load a corresponding slice from Shared memory
      source_tile_it.load(source_fragment);
      ++source_tile_it;

      // Compute linear scaling - alpha * AB + beta * C
      source_fragment = mul_source(beta, source_fragment);
      accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment);

      // Store the result to shared memory
      dest_tile_it.store(accum_fragment);
      ++dest_tile_it;
    }
  }
};

} // namespace warp
} // namespace gemm
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held
// in Shared Memory.
__global__ void kernel(
  double *D_gmem,
  double alpha,
  double const *A_gmem,
  double const *B_gmem,
  double beta,
  double const *C_gmem) {

  // Define several matrices in shared memory
  __shared__ double A[kM][kK];
  __shared__ double B[kN][kK];
  __shared__ double C[kM][kN];

  // Copy data into SMEM
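  // (Note added for illustration: staging through a single thread keeps the
  // example easy to follow; it is not a performant global-to-shared copy.)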
  if (threadIdx.x == 0) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (int m = 0; m < kM; ++m) {
      for (int k = 0; k < kK; ++k) {
        A[m][k] = A_gmem[m * kK + k];
      }
    }
    CUTLASS_PRAGMA_NO_UNROLL
    for (int n = 0; n < kN; ++n) {
      for (int k = 0; k < kK; ++k) {
        B[n][k] = B_gmem[n * kK + k];
      }
    }
    CUTLASS_PRAGMA_NO_UNROLL
    for (int m = 0; m < kM; ++m) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (int n = 0; n < kN; ++n) {
        C[m][n] = C_gmem[m * kN + n];
      }
    }
  }

  __syncthreads();

  //
  // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4),
  // overall shape, data type of each operand, and layout of each operand.
  //

  using GemmTensorOp = cutlass::gemm::warp::GemmTensorOp<
    cutlass::gemm::GemmShape<kM, kN, kK>,
    cutlass::gemm::GemmShape<8, 8, 4>,
    double,                        // Data type of A elements
    cutlass::layout::RowMajor,     // Layout of A matrix
    double,                        // Data type of B elements
    cutlass::layout::ColumnMajor,  // Layout of B matrix
    double,                        // Data type of C elements
    cutlass::layout::RowMajor,     // Layout of C matrix
    double                         // Scalar type of alpha and beta
  >;

  // Instantiate the GEMM operator
  GemmTensorOp gemm;

  // Execute the warp-level GEMM operation
  gemm(
    alpha,
    {&A[0][0], kK},
    {&B[0][0], kK},
    beta,
    {&C[0][0], kN},
    {&C[0][0], kN},
    threadIdx.x);

  __syncthreads();

  // Copy data out of SMEM
  if (threadIdx.x == 0) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (int m = 0; m < kM; ++m) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (int n = 0; n < kN; ++n) {
        D_gmem[m * kN + n] = C[m][n];
      }
    }
  }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Entry point to canonical warp-level GEMM operation
int main(int argc, const char *arg[]) {

  bool notSupported = false;

  // CUTLASS must be compiled with the CUDA 11 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ >= 11)) {
    std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
    notSupported = true;
  }

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (!((props.major * 10 + props.minor) >= 80)) {
    std::cerr << "This example requires compute capability at least 80."
              << std::endl;
    notSupported = true;
  }

  if (notSupported) {
    // Return 0 so tests are considered passing if run on unsupported platforms.
    return 0;
  }

  cutlass::HostTensor<double, cutlass::layout::RowMajor> A({kM, kK});
  cutlass::HostTensor<double, cutlass::layout::ColumnMajor> B({kK, kN});
  cutlass::HostTensor<double, cutlass::layout::RowMajor> C({kM, kN});
  cutlass::HostTensor<double, cutlass::layout::RowMajor> D({kM, kN});

  uint64_t seed = 2020;
  double max = 8;
  double min = -8;

  cutlass::reference::host::TensorFillRandomUniform(
    A.host_view(),
    seed,
    max,
    min,
    0
  );

  cutlass::reference::host::TensorFillRandomUniform(
    B.host_view(),
    seed + 17,
    max,
    min,
    0
  );

  cutlass::reference::host::TensorFillRandomUniform(
    C.host_view(),
    seed + 31,
    max,
    min,
    0
  );

  A.sync_device();
  B.sync_device();
  C.sync_device();
  D.sync_device();

  dim3 grid(1,1);
  dim3 block(32, 1, 1);

  double alpha = 2.25;
  double beta = 1.24;

  kernel<<< grid, block >>>(
    D.device_data(),
    alpha,
    A.device_data(),
    B.device_data(),
    beta,
    C.device_data()
  );

  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Failed to synchronize device after kernel launch." << std::endl;
    return -1;
  }

  D.sync_host();

  // Compute reference on host
  cutlass::HostTensor<double, cutlass::layout::RowMajor> D_ref({kM, kN}, false);

  cutlass::reference::host::GemmComplex(
    {kM, kN, kK},
    alpha,
    A.host_ref(),
    cutlass::ComplexTransform::kNone,
    B.host_ref(),
    cutlass::ComplexTransform::kNone,
    beta,
    C.host_ref(),
    D_ref.host_ref(),
    double()
  );

  // Verify reference matches computed
  if (!cutlass::reference::host::TensorEquals(
    D.host_view(),
    D_ref.host_view())) {

    std::cerr
      << "A =\n" << A.host_view()
      << "\n\nB = \n" << B.host_view()
      << "\n\nC = " << C.host_view()
      << "\n\nRef =\n" << D_ref.host_view()
      << "\n\nD =\n" << D.host_view() << "\n\n";

    std::cerr << "Error - device results mismatch host reference." << std::endl;

    return -1;
  }

  std::cout << "Passed" << std::endl;

  return 0;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

27
examples/20_simt_canonical/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  20_simt_canonical
  simt_canonical.cu
)

419
examples/20_simt_canonical/simt_canonical.cu
Normal file
@ -0,0 +1,419 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*
  This example requires an NVIDIA Maxwell GPU or later.
*/

// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>

// CUTLASS Includes
#include "cutlass/cutlass.h"
#include "cutlass/core_io.h"
#include "cutlass/functional.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/warp/mma_simt.h"
#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
#include "cutlass/epilogue/warp/tile_iterator_simt.h"

// CUTLASS Utility Includes
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/gemm_complex.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

// Define the overall warp-level problem shape
int const kM = 14;
int const kN = 27;
int const kK = 17;

///////////////////////////////////////////////////////////////////////////////////////////////////

// Define a warp-level GEMM operator.
//
// This template could be part of the CUTLASS Template Library or implemented internally. This
// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be
// instantiated in device code.

namespace cutlass {
namespace gemm {
namespace warp {

template <
  typename Shape,
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementScalar
>
class GemmSimt {
public:

  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
    cutlass::MatrixShape<4, 8>,
    cutlass::layout::RowMajorInterleaved<2>,
    cutlass::gemm::GemmShape<4, 4, 1>
  >;
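  // (Worked example added for illustration.) The policy arranges the warp's 32 threads
  // as a 4 x 8 grid (row-major interleaved), each thread computing a 4 x 4 x 1 tile per
  // step, so one step covers the 16 x 32 warp tile used by MmaWarp below.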
|
||||
|
||||
using MmaWarp = cutlass::gemm::warp::MmaSimt<
|
||||
cutlass::gemm::GemmShape<16, 32, 8>,
|
||||
float,
|
||||
cutlass::layout::RowMajor,
|
||||
float,
|
||||
cutlass::layout::ColumnMajor,
|
||||
float,
|
||||
cutlass::layout::RowMajor,
|
||||
Policy
|
||||
>;
|
||||
|
||||
// Number of 'K groups'
|
||||
int const kKgroups = Shape::kK;
|
||||
|
||||
using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
|
||||
typename MmaWarp::Shape,
|
||||
    typename MmaWarp::ThreadMma,
    layout::RowMajor,                    // SMEM layout
    typename MmaWarp::Policy
  >;

  using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical<
    typename MmaWarp::Shape,
    typename MmaWarp::ThreadMma,
    float,                               // ElementAccumulator
    layout::RowMajor,                    // SMEM layout
    typename MmaWarp::Policy
  >;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
  using TensorRefC = typename AccumulatorTileIterator::TensorRef;

public:
  CUTLASS_HOST_DEVICE
  GemmSimt() { }

  CUTLASS_DEVICE
  void operator()(
    ElementScalar alpha,
    TensorRefA ref_A,
    TensorRefB ref_B,
    ElementScalar beta,
    TensorRefC ref_C,
    TensorRefC ref_D,
    int lane_id) const {

    // Instantiate iterators pointing to slices of the A and B matrices in shared memory
    typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id);
    typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id);

    // Instantiate and clear accumulator tile holding the C matrix
    typename MmaWarp::FragmentC accum;
    accum.clear();

    // Instantiate the warp-level matrix multiply operator
    MmaWarp mma_op;

    // Instantiate fragments holding the slice of the matrix held by each warp
    typename MmaWarp::FragmentA frag_A[2];
    typename MmaWarp::FragmentB frag_B[2];

    // Load the first fragments from shared memory
    iter_A.load(frag_A[0]);
    iter_B.load(frag_B[0]);

    ++iter_A;
    ++iter_B;

    // Mainloop over the K dimension, double-buffering the fragments
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {

      // Load the next fragments from shared memory
      iter_A.load(frag_A[(k + 1) % 2]);
      iter_B.load(frag_B[(k + 1) % 2]);

      ++iter_A;
      ++iter_B;

      // Compute the matrix multiply
      mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
    }

    // Instantiate iterators over the accumulator tile and the C and D tiles
    FragmentIterator accum_frag_it(accum);
    AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id);
    AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id);

    // Define function objects for the linear scaling operation
    cutlass::multiplies<typename FragmentIterator::Fragment> mul_source;
    cutlass::multiply_add<typename FragmentIterator::Fragment> mul_add_accumulator;

    // Iterate over the epilogue components
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) {

      // Define storage for slices of the accumulators
      typename FragmentIterator::Fragment accum_fragment;
      typename FragmentIterator::Fragment source_fragment;

      // Select a slice of accumulators from the accumulator tile
      accum_frag_it.load(accum_fragment);
      ++accum_frag_it;

      // Load a corresponding slice from shared memory
      source_tile_it.load(source_fragment);
      ++source_tile_it;

      // Compute linear scaling: alpha * AB + beta * C
      source_fragment = mul_source(beta, source_fragment);
      accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment);

      // Store the result to shared memory
      dest_tile_it.store(accum_fragment);
      ++dest_tile_it;
    }
  }
};

} // namespace warp
} // namespace gemm
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held
// in shared memory.
__global__ void kernel(
  float *D_gmem,
  float alpha,
  float const *A_gmem,
  float const *B_gmem,
  float beta,
  float const *C_gmem) {

  // Define several matrices in shared memory
  __shared__ float A[kM][kK];
  __shared__ float B[kN][kK];
  __shared__ float C[kM][kN];

  // Copy data from global memory into SMEM
  if (threadIdx.x == 0) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (int m = 0; m < kM; ++m) {
      for (int k = 0; k < kK; ++k) {
        A[m][k] = A_gmem[m * kK + k];
      }
    }
    CUTLASS_PRAGMA_NO_UNROLL
    for (int n = 0; n < kN; ++n) {
      for (int k = 0; k < kK; ++k) {
        B[n][k] = B_gmem[n * kK + k];
      }
    }
    CUTLASS_PRAGMA_NO_UNROLL
    for (int m = 0; m < kM; ++m) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (int n = 0; n < kN; ++n) {
        C[m][n] = C_gmem[m * kN + n];
      }
    }
  }

  __syncthreads();

  //
  // Instantiate a warp-level matrix multiply operator given the overall shape,
  // data type of each operand, and layout of each operand.
  //

  using GemmSimt = cutlass::gemm::warp::GemmSimt<
    cutlass::gemm::GemmShape<kM, kN, kK>,
    float,                              // Data type of A elements
    cutlass::layout::RowMajor,          // Layout of A matrix
    float,                              // Data type of B elements
    cutlass::layout::ColumnMajor,       // Layout of B matrix
    float,                              // Data type of C elements
    cutlass::layout::RowMajor,          // Layout of C matrix
    float                               // Scalar type of alpha and beta
  >;

  // Instantiate the GEMM operator
  GemmSimt gemm;

  // Execute the warp-level GEMM operation
  gemm(
    alpha,
    {&A[0][0], kK},
    {&B[0][0], kK},
    beta,
    {&C[0][0], kN},
    {&C[0][0], kN},
    threadIdx.x);

  __syncthreads();

  // Copy the result out of SMEM into global memory
  if (threadIdx.x == 0) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (int m = 0; m < kM; ++m) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (int n = 0; n < kN; ++n) {
        D_gmem[m * kN + n] = C[m][n];
      }
    }
  }
}

///////////////////////////////////////////////////////////////////////////////////////////////////
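
// Illustrative aside (not part of the example): the braced initializers passed to gemm(...)
// above construct CUTLASS TensorRef objects, which pair a raw pointer with a layout; for
// RowMajor and ColumnMajor layouts the second element is the leading dimension. A minimal
// standalone sketch of the equivalent explicit construction (the strides of 16 here are
// illustrative, not the example's actual values):
//
//   #include "cutlass/tensor_ref.h"
//   #include "cutlass/layout/matrix.h"
//
//   void tensor_ref_sketch(float *ptr) {
//     // Explicitly building what {&A[0][0], kK} builds implicitly: pointer + leading dimension
//     cutlass::TensorRef<float, cutlass::layout::RowMajor> ref_A(ptr, cutlass::layout::RowMajor(16));
//     cutlass::TensorRef<float, cutlass::layout::ColumnMajor> ref_B(ptr, cutlass::layout::ColumnMajor(16));
//     (void)ref_A;
//     (void)ref_B;
//   }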

int main(int argc, const char *arg[]) {

  cutlass::HostTensor<float, cutlass::layout::RowMajor> A({kM, kK});
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({kK, kN});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> C({kM, kN});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> D({kM, kN});

  uint64_t seed = 2020;
  float max = 8;
  float min = -8;

  std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape<kM, kN, kK>() << ")" << std::endl;

  cutlass::reference::host::TensorFillRandomUniform(
    A.host_view(),
    seed,
    max,
    min,
    0
  );

  cutlass::reference::host::TensorFillRandomUniform(
    B.host_view(),
    seed + 17,
    max,
    min,
    0
  );

#if 0 // Debug: fill A sequentially and B as identity matrix
  cutlass::reference::host::BlockFillSequential(
    A.host_view().data(), A.host_view().capacity());

  cutlass::reference::host::TensorFillIdentity(B.host_view());
#endif

  cutlass::reference::host::TensorFillRandomUniform(
    C.host_view(),
    seed + 31,
    max,
    min,
    0
  );

  A.sync_device();
  B.sync_device();
  C.sync_device();
  D.sync_device();

  dim3 grid(1, 1);
  dim3 block(32, 1, 1);

  float alpha = 1.0f;
  float beta = 0.0f;

  kernel<<< grid, block >>>(
    D.device_data(),
    alpha,
    A.device_data(),
    B.device_data(),
    beta,
    C.device_data()
  );

  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Failed to synchronize device after kernel launch." << std::endl;
    return -1;
  }

  D.sync_host();

  // Compute reference on host
  cutlass::HostTensor<float, cutlass::layout::RowMajor> D_ref({kM, kN}, false);
  cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view());

  cutlass::reference::host::Gemm<
    float, cutlass::layout::RowMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    float, float> reference_gemm;

  reference_gemm(
    {kM, kN, kK},
    alpha,
    A.host_ref(),
    B.host_ref(),
    beta,
    D_ref.host_ref(),
    float()
  );

  // Verify the device result matches the host reference
  if (!cutlass::reference::host::TensorEquals(
    D.host_view(),
    D_ref.host_view())) {

    std::cerr
      << "A =\n" << A.host_view()
      << "\n\nB =\n" << B.host_view()
      << "\n\nC =\n" << C.host_view()
      << "\n\nRef =\n" << D_ref.host_view()
      << "\n\nD =\n" << D.host_view() << "\n\n";

    std::cerr << "Error - device results mismatch host reference." << std::endl;

    return -1;
  }

  std::cout << "Passed" << std::endl;

  return 0;
}

///////////////////////////////////////////////////////////////////////////////////////////////////
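
The excerpt above references file-scope constants `kM`, `kN`, `kK` and the class-scope constant `kKgroups`, all defined earlier in the example (outside this excerpt). A minimal sketch of plausible definitions, with illustrative values only:

// Hypothetical problem-size constants (the example's actual values appear earlier in the file).
static int const kM = 64;   // rows of A and C held in shared memory
static int const kN = 32;   // columns of B and C
static int const kK = 16;   // inner GEMM dimension

// Inside GemmSimt, kKgroups would then count mainloop steps over the K dimension, e.g.
// static int const kKgroups = Shape::kK / MmaWarp::Shape::kK;   // assumed, not shown above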

examples/21_quaternion_gemm/CMakeLists.txt (new file, 27 lines)
@ -0,0 +1,27 @@

# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  21_quaternion_gemm
  quaternion_gemm.cu
)
examples/21_quaternion_gemm/quaternion_gemm.cu (new file, 448 lines)
@ -0,0 +1,448 @@

/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result {

  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  //
  // Methods
  //

  Result(
    double runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess
  ):
    runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;

  cutlass::gemm::GemmCoord problem_size;
  int batch_count;
  cutlass::Quaternion<float> alpha;
  cutlass::Quaternion<float> beta;

  bool reference_check;
  int iterations;

  Options():
    help(false),
    problem_size({1024, 1024, 1024}),
    batch_count(1),
    alpha(1),
    beta(),
    reference_check(true),
    iterations(20) { }

  bool valid() {
    return true;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("m", problem_size.m());
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());
    cmd.get_cmd_line_argument("batch", batch_count);

    cmd.get_cmd_line_argument("alpha", alpha.w());
    cmd.get_cmd_line_argument("alpha_i", alpha.x());
    cmd.get_cmd_line_argument("alpha_j", alpha.y());
    cmd.get_cmd_line_argument("alpha_k", alpha.z());

    cmd.get_cmd_line_argument("beta", beta.w());
    cmd.get_cmd_line_argument("beta_i", beta.x());
    cmd.get_cmd_line_argument("beta_j", beta.y());
    cmd.get_cmd_line_argument("beta_k", beta.z());

    cmd.get_cmd_line_argument("iterations", iterations);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "21_quaternion_gemm example\n\n"
      << "  This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --m <int>                   GEMM M dimension\n"
      << "  --n <int>                   GEMM N dimension\n"
      << "  --k <int>                   GEMM K dimension\n"
      << "  --batch <int>               Number of GEMM operations executed in one batch\n"
      << "  --alpha <f32>               Epilogue scalar alpha (real part)\n"
      << "  --alpha_i <f32>             Epilogue scalar alpha_i (imaginary part)\n"
      << "  --alpha_j <f32>             Epilogue scalar alpha_j (imaginary part)\n"
      << "  --alpha_k <f32>             Epilogue scalar alpha_k (imaginary part)\n"
      << "  --beta <f32>                Epilogue scalar beta (real part)\n"
      << "  --beta_i <f32>              Epilogue scalar beta_i (imaginary part)\n"
      << "  --beta_j <f32>              Epilogue scalar beta_j (imaginary part)\n"
      << "  --beta_k <f32>              Epilogue scalar beta_k (imaginary part)\n\n"
      << "  --iterations <int>          Number of profiling iterations to perform.\n\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm --batch=7 --m=1024 --n=512 --k=1024 \\\n"
      << "     --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds: each quaternion multiply-add expands to 16
    int64_t fmas = problem_size.product() * batch_count * 16;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};
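
// Illustrative aside (not part of the example): the factor of 16 in gflops() comes from
// expanding the Hamilton product of two quaternions. Written with plain floats:
//
//   struct Quat { float w, x, y, z; };
//
//   Quat hamilton_product(Quat q, Quat r) {
//     Quat p;
//     p.w = q.w * r.w - q.x * r.x - q.y * r.y - q.z * r.z;
//     p.x = q.w * r.x + q.x * r.w + q.y * r.z - q.z * r.y;
//     p.y = q.w * r.y - q.x * r.z + q.y * r.w + q.z * r.x;
//     p.z = q.w * r.z + q.x * r.y - q.y * r.x + q.z * r.w;
//     return p;
//   }
//
// Sixteen real multiplies (and twelve adds) per quaternion multiply, conventionally
// counted as 16 real multiply-adds when estimating throughput.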

///////////////////////////////////////////////////////////////////////////////////////////////////

// The code section below describes datatypes for the input and output matrices and the
// computation between elements of the input matrices.
using precision = float;
using Element = cutlass::Quaternion<float>;
using ElementComputeEpilogue = Element;  // <- data type of epilogue operations
using ElementAccumulator = Element;      // <- data type of accumulator
using ElementInputA = Element;           // <- data type of elements in input matrix A
using ElementInputB = Element;           // <- data type of elements in input matrix B
using ElementOutput = Element;           // <- data type of elements in output matrix D

// The code section below describes the matrix layout of input and output matrices:
// Row Major for Matrix A, Column Major for Matrix B, and Row Major for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassSimt;

// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm50;

// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<64, 64, 4>;  // <- threadblock tile M = 64, N = 64, K = 4
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>;  // <- warp tile M = 32, N = 16, K = 4
// This code section describes the size of the MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>;  // <- MMA Op tile M = 1, N = 1, K = 1

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- Defaults

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                     // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
                                                       //    memory access (one 128-bit quaternion
                                                       //    here). This becomes the vector width
                                                       //    of math instructions in the epilogue too
    ElementAccumulator,                                // <- data type of accumulator
    ElementComputeEpilogue>;  // <- data type for alpha/beta in linear combination function

// Number of pipeline stages to use
constexpr int NumStages = 2;

using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages>;
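
// Illustrative aside (not part of the example): the tile shapes above fix the warp layout.
//   warps per threadblock = (ShapeMMAThreadBlock::kM / ShapeMMAWarp::kM)
//                         * (ShapeMMAThreadBlock::kN / ShapeMMAWarp::kN)
//                         = (64 / 32) * (64 / 16) = 8 warps, i.e. 256 threads per threadblock.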

int run(Options options) {

  // PASS/FAIL status
  bool passed = true;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size = options.problem_size;

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
      problem_size.mk());  // <- Create matrix A with dimensions M x K
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
      problem_size.kn());  // <- Create matrix B with dimensions K x N
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
      problem_size.mn());  // <- Create matrix C with dimensions M x N
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
                           //    CUTLASS kernel
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
                           //    reference kernel

  // Fill input and output matrices on host using CUTLASS helper functions
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      4,
      -4,
      0);  // <- Fill matrix A on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      4,
      -4,
      0);  // <- Fill matrix B on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      4,
      -4,
      0);  // <- Fill matrix C on host with uniform-distribution random data

  cutlass::reference::host::TensorFill(
      tensor_d.host_view());  // <- fill matrix D on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  // Initialize alpha and beta for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  // Split K dimension into 1 partition
  int split_k_slices = 1;

  // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
  // the instantiated CUTLASS kernel
  typename Gemm::Arguments arguments{problem_size,           // <- problem size of matrix multiplication
                                     tensor_a.device_ref(),  // <- reference to matrix A on device
                                     tensor_b.device_ref(),  // <- reference to matrix B on device
                                     tensor_c.device_ref(),  // <- reference to matrix C on device
                                     tensor_d.device_ref(),  // <- reference to matrix D on device
                                     {alpha, beta},          // <- tuple of alpha and beta
                                     split_k_slices};        // <- k-dimension split factor

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;

  // Check whether the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // Result structure
  Result result;

  //
  // Construct events
  //

  cudaEvent_t events[2];

  for (auto & event : events) {
    result.error = cudaEventCreate(&event);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
      return -1;
    }
  }

  // Record an event at the start of a series of GEMMs
  result.error = cudaEventRecord(events[0]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  //
  // Run profiling loop
  //

  for (int iter = 0; iter < options.iterations; ++iter) {

    // Launch initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);
  }

  //
  // Stop profiling loop
  //

  // Record an event when the GEMMs are complete
  result.error = cudaEventRecord(events[1]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  // Wait for work on the device to complete.
  result.error = cudaEventSynchronize(events[1]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  // Measure elapsed runtime
  float runtime_ms = 0;
  result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  // Compute average runtime and GFLOPs.
  result.runtime_ms = double(runtime_ms) / double(options.iterations);
  result.gflops = options.gflops(result.runtime_ms / 1000.0);

  // Cleanup
  for (auto event : events) {
    (void)cudaEventDestroy(event);
  }

  if (options.reference_check) {

    // Create instantiation for device reference gemm kernel
    cutlass::reference::device::Gemm<ElementInputA,
                                     LayoutInputA,
                                     ElementInputB,
                                     LayoutInputB,
                                     ElementOutput,
                                     LayoutOutput,
                                     ElementComputeEpilogue,
                                     ElementComputeEpilogue> gemm_device;

    // Launch device reference gemm kernel
    gemm_device(problem_size,
                alpha,
                tensor_a.device_ref(),
                tensor_b.device_ref(),
                beta,
                tensor_c.device_ref(),
                tensor_ref_d.device_ref());

    // Wait for kernels to finish
    cudaDeviceSynchronize();

    // Copy output data from CUTLASS and reference kernel to host for comparison
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    passed &= cutlass::reference::host::TensorEquals(
        tensor_d.host_view(),
        tensor_ref_d.host_view());
  }

  if (passed) {
    std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
    std::cout << " GFLOPs: " << result.gflops << std::endl;
  }

  std::cout << (passed ? "Passed" : "Failed") << std::endl;
  return (passed ? 0 : -1);
}

int main(int argc, char const** argv) {

  Options options;
  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n",
         options.problem_size.m(), options.problem_size.n(), options.problem_size.k());

  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

  return run(options);
}

examples/22_quaternion_conv/CMakeLists.txt (new file, 28 lines)
@ -0,0 +1,28 @@

# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  22_quaternion_conv
  quaternion_conv.cu
)
examples/22_quaternion_conv/quaternion_conv.cu (new file, 660 lines)
@ -0,0 +1,660 @@

/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/convolution.h"

#include "helper.h"

// The code section below describes datatypes for the input and output tensors and the
// computation between elements
using Element = cutlass::Quaternion<float>;
using ElementAccumulator = Element;      // Data type of accumulator
using ElementComputeEpilogue = Element;  // Data type of epilogue computation (alpha, beta)
using ElementInputA = Element;           // Data type of elements in input tensor
using ElementInputB = Element;           // Data type of elements in input tensor
using ElementOutput = Element;           // Data type of elements in output tensor

using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassSimt;

// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm50;

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>;  // Threadblock tile shape

// This code section describes the tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>;         // Warp tile shape

// This code section describes the size of the MMA op
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;    // SIMT instruction shape

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// Number of pipeline stages to use
constexpr int NumStages = 2;

// This code section describes whether the Analytic or Optimized iterator algorithm is selected
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;

// This code section describes the epilogue part of the kernel; we use the default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                     // Data type of output matrix.
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // The number of elements per vectorized
                                                       // memory access. This becomes the vector width of
                                                       // math instructions in the epilogue too.
    ElementAccumulator,                                // Data type of accumulator
    ElementComputeEpilogue>;                           // Data type for alpha/beta in linear combination

using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
  ElementInputA, LayoutInputA,
  ElementInputB, LayoutInputB,
  ElementOutput, LayoutOutput,
  ElementAccumulator,
  MMAOp,
  SmArch,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  cutlass::arch::OpMultiplyAdd,
  IteratorAlgorithm
>::Kernel;

using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;
  cutlass::Tensor4DCoord input_size;
  cutlass::Tensor4DCoord filter_size;
  cutlass::Tensor4DCoord padding;
  cutlass::MatrixCoord conv_stride;
  cutlass::MatrixCoord dilation;
  bool reference_check;
  bool measure_performance;
  int iterations;
  bool save_workspace;
  ElementComputeEpilogue alpha;
  ElementComputeEpilogue beta;
  bool benchmark;
  std::string tag;

  Options():
    help(false),
    input_size(1, 32, 32, 32),
    filter_size(32, 3, 3, 32),
    padding(1, 1, 1, 1),
    conv_stride(1, 1),
    dilation(1, 1),
    reference_check(false),
    measure_performance(true),
    iterations(20),
    save_workspace(false),
    alpha(1),
    beta(0),
    benchmark(false) { }

  // Verify the problem size is compatible with the CUTLASS Convolution implementation.
  bool valid() {

    //
    // Vectorized memory accesses require all pointers, strides, and tensor extents to be
    // divisible by kAlignment elements.
    //
    int const kAlignment = 8;

    if ((input_size.c() % kAlignment) ||
        (filter_size.n() % kAlignment)) {

      // misaligned tensors
      return false;
    }

    // Invalid padding
    if ((padding.h() != filter_size.h() / 2) ||
        (padding.w() != filter_size.w() / 2)) {

      return false;
    }

    return true;
  }

  /// Updates input and filter sizes
  void update(
      cutlass::Tensor4DCoord input_size,
      cutlass::Tensor4DCoord filter_size) {

    this->input_size = input_size;
    this->filter_size = filter_size;

    padding.n() = filter_size.h() / 2;
    padding.h() = filter_size.h() / 2;
    padding.w() = filter_size.w() / 2;
    padding.c() = filter_size.w() / 2;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    if (cmd.check_cmd_line_flag("ref-check")) {
      reference_check = true;
    }

    if (cmd.check_cmd_line_flag("perf-check")) {
      measure_performance = true;
    }

    if (cmd.check_cmd_line_flag("save-workspace")) {
      save_workspace = true;
    }

    if (cmd.check_cmd_line_flag("benchmark")) {
      benchmark = true;
    }

    cmd.get_cmd_line_argument("n", input_size.n());
    cmd.get_cmd_line_argument("h", input_size.h());
    cmd.get_cmd_line_argument("w", input_size.w());
    cmd.get_cmd_line_argument("c", input_size.c());

    cmd.get_cmd_line_argument("k", filter_size.n());
    cmd.get_cmd_line_argument("r", filter_size.h());
    cmd.get_cmd_line_argument("s", filter_size.w());
    filter_size.c() = input_size.c();

    cmd.get_cmd_line_argument("alpha_w", alpha.w());
    cmd.get_cmd_line_argument("alpha_x", alpha.x());
    cmd.get_cmd_line_argument("alpha_y", alpha.y());
    cmd.get_cmd_line_argument("alpha_z", alpha.z());

    cmd.get_cmd_line_argument("beta_w", beta.w());
    cmd.get_cmd_line_argument("beta_x", beta.x());
    cmd.get_cmd_line_argument("beta_y", beta.y());
    cmd.get_cmd_line_argument("beta_z", beta.z());

    cmd.get_cmd_line_argument("iterations", iterations);
    cmd.get_cmd_line_argument("tag", tag);

    if (filter_size.h() == 3 && filter_size.w() == 3) {
      padding = {1, 1, 1, 1};
    }
    else {
      filter_size.h() = 1;
      filter_size.w() = 1;
      padding = {0, 0, 0, 0};
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "22_quaternion_conv example\n\n"
      << "  This example uses CUTLASS SIMT operators on single-precision quaternion data types\n"
      << "  to compute forward convolution on tensors of layout NHWC.\n\n"
      << "Options:\n\n"
      << "  --help               If specified, displays this usage statement.\n\n"
      << "  --n <int>            Input tensor extent N\n"
      << "  --h <int>            Input tensor extent H\n"
      << "  --w <int>            Input tensor extent W\n"
      << "  --c <int>            Input tensor extent C\n"
      << "  --k <int>            Filter extent K\n"
      << "  --r <int>            Filter extent R\n"
      << "  --s <int>            Filter extent S\n\n"
      << "  --alpha_w,--alpha_x,--alpha_y,--alpha_z <float>  Epilogue scalar alpha (quaternion components)\n"
      << "  --beta_w,--beta_x,--beta_y,--beta_z <float>      Epilogue scalar beta (quaternion components)\n\n"
      << "  --ref-check          If set (true), reference check on the host is computed\n"
      << "  --perf-check         If set (true), performance is measured.\n"
      << "  --benchmark          If set (true), performance benchmarking on several layers and batch-size.\n"
      << "  --iterations <int>   Number of profiling iterations to perform.\n"
      << "  --save-workspace     If set, workspace is written to a text file.\n"
      << "  --tag <string>       String to replicate across the first column in the results table\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
      << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";

    return out;
  }

  /// Computes the output tensor size (NPQK)
  cutlass::Tensor4DCoord output_size() const {
    return cutlass::Tensor4DCoord(
      input_size.n(),
      (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
      (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
      filter_size.n());
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds: NPQK * CRS * 16, since each quaternion
    // multiply-add expands to 16 real multiply-adds
    int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()) * 16;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};
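
// Illustrative aside (not part of the example): checking output_size() by hand for the
// default problem. Input 1x32x32x32 (NHWC), filter 32x3x3x32, padding 1, stride 1:
//   P = (32 + 1 + 1 - 3) / 1 + 1 = 32
//   Q = (32 + 1 + 1 - 3) / 1 + 1 = 32
// so the output tensor is 1x32x32x32 (NPQK).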

/////////////////////////////////////////////////////////////////////////////////////////////////

struct Result {
  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cutlass::Status reference_check;
  cudaError_t error;

  Result():
    runtime_ms(0),
    gflops(0),
    status(cutlass::Status::kSuccess),
    reference_check(cutlass::Status::kInvalid),
    error(cudaSuccess) { }

  static std::ostream & print_header(std::ostream &out, Options const &options) {

    if (!options.tag.empty()) {
      out << "Name,";
    }

    out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";

    return out;
  }

  std::ostream & print(std::ostream &out, int idx, Options const &options) {

    if (!options.tag.empty()) {
      out << options.tag << ",";
    }

    out
      << "conv_" << idx << ","
      << options.input_size.n() << ","
      << options.input_size.h() << ","
      << options.input_size.w() << ","
      << options.input_size.c() << ","
      << options.filter_size.n() << ","
      << options.filter_size.h() << ","
      << options.filter_size.w() << ","
      << runtime_ms << ","
      << gflops;

    return out;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Runs one benchmark
Result profile_convolution(Options const &options) {

  Result result;

  //
  // Allocate host-device tensors using the CUTLASS Utilities.
  //

  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());

  //
  // Initialize tensors
  //

  // Fill tensor A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      7,
      -8,
      0);

  // Fill tensor B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      7,
      -8,
      0);

  // Fill tensor C on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_c.host_view());

  // Fill tensor C for reference on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_c.host_view());

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_ref_c.sync_device();

  //
  // Define arguments for CUTLASS Convolution
  //

  cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;

  // Split K dimension into 1 partition
  int split_k_slices = 1;

  // Construct Conv2dProblemSize with user defined output size
  cutlass::conv::Conv2dProblemSize problem_size(
      options.input_size,
      options.filter_size,
      options.padding,
      options.conv_stride,
      options.dilation,
      options.output_size(),
      mode,
      split_k_slices
  );

  // Construct ImplicitGemm::Arguments structure with conv2d
  // problem size, data pointers, and epilogue values
  typename ImplicitGemm::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),
    tensor_b.device_ref(),
    tensor_c.device_ref(),
    tensor_c.device_ref(),
    {options.alpha, options.beta},
  };

  //
  // Initialize CUTLASS Convolution
  //

  ImplicitGemm implicit_gemm_op;

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  result.status = implicit_gemm_op.can_implement(arguments);
  CUTLASS_CHECK(result.status);

  result.status = implicit_gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(result.status);

  //
  // Launch initialized CUTLASS kernel
  //
  result.status = implicit_gemm_op();

  CUTLASS_CHECK(result.status);

  //
  // Optional reference check
  //

  if (options.reference_check) {
    std::cout << "Verification on host...\n";

    // Compute with reference implementation
    cutlass::reference::host::Conv2dFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementComputeEpilogue,
      ElementAccumulator,
      cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
    >(
      problem_size,
      tensor_a.host_ref(),
      tensor_b.host_ref(),
      tensor_c.host_ref(),
      tensor_ref_c.host_ref(),
      options.alpha,
      options.beta
    );

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    tensor_c.sync_host();

    bool passed = cutlass::reference::host::TensorEquals(
      tensor_c.host_view(),
      tensor_ref_c.host_view());

    if (!passed) {
      result.reference_check = cutlass::Status::kErrorInternal;
      std::cout << "ERROR - results miscompared.\n";
    }
    else {
      result.reference_check = cutlass::Status::kSuccess;
      std::cout << "Passed.\n";
    }
  }
  else {
    result.reference_check = cutlass::Status::kInvalid;
  }

  if (options.save_workspace) {

    std::stringstream ss;

    ss << "22_quaternion_conv_"
      << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
      << "_"
      << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
      << ".dat";

    std::ofstream output_workspace(ss.str());

    output_workspace
      << "Input = \n" << tensor_a.host_view() << "\n\n"
      << "Filters = \n" << tensor_b.host_view() << "\n\n";

    if (options.reference_check) {
      output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
    }

    output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;

    std::cout << "Results written to '" << ss.str() << "'." << std::endl;
  }

  //
  // Performance measurement
  //

  if (options.measure_performance) {

    cudaEvent_t events[2];

    for (auto & event : events) {
      result.error = cudaEventCreate(&event);
      if (result.error != cudaSuccess) {
        std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
        return result;
      }
    }

    // Record an event at the start of a series of convolution operations.
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Launch a sequence of implicit GEMM operations on the device
    for (int iteration = 0; iteration < options.iterations; ++iteration) {
      result.status = implicit_gemm_op();
      CUTLASS_CHECK(result.status);
    }

    // Record an event when the convolutions have been launched.
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Print average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    for (auto event : events) {
      (void)cudaEventDestroy(event);
    }
  }

  return result;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.benchmark) {
    // Benchmark several layers

    int batch_sizes[] = {1, 32, 64, 128, 256, 512};

    struct Benchmark {
      int h, w, c, k, r, s;
    } layers[] = {
      {56, 56,   64,  256, 1, 1},
      {56, 56,   64,   64, 1, 1},
      {56, 56,   64,   64, 3, 3},
      {56, 56,  256,   64, 1, 1},
      {56, 56,  256,  512, 1, 1},
      {56, 56,  256,  128, 1, 1},
      {28, 28,  128,  128, 3, 3},
      {28, 28,  128,  512, 1, 1},
      {28, 28,  512,  128, 1, 1},
      {28, 28,  512, 1024, 1, 1},
      {28, 28,  512,  256, 1, 1},
      {14, 14,  256,  256, 3, 3},
      {14, 14,  256, 1024, 1, 1},
      {14, 14, 1024,  256, 1, 1},
      {14, 14, 1024, 2048, 1, 1},
      {14, 14, 1024,  512, 1, 1},
      { 7,  7,  512,  512, 3, 3},
    };

    Result::print_header(std::cout, options) << std::endl;

    int idx = 1;

    for (auto const &layer : layers) {
      for (auto N : batch_sizes) {

        options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});

        Result result = profile_convolution(options);
        result.print(std::cout, idx, options) << std::endl;
      }

      ++idx;
    }
  }
  else {

    // Execute one problem size
    if (!options.valid()) {
      std::cerr << "Invalid problem." << std::endl;
      return -1;
    }

    Result result = profile_convolution(options);

    Result::print_header(std::cout, options) << std::endl;
    result.print(std::cout, 1, options) << std::endl;
  }

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -28,10 +28,14 @@ add_custom_target(test_examples)

 function(cutlass_example_add_executable NAME)

   set(options)
-  set(oneValueArgs)
+  set(oneValueArgs DISABLE_TESTS)
   set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
   cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

+  if (NOT DEFINED __DISABLE_TESTS)
+    set(__DISABLE_TESTS OFF)
+  endif()
+
   cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS})

   add_dependencies(cutlass_examples ${NAME})
@ -60,6 +64,7 @@ function(cutlass_example_add_executable NAME)
     DEPENDEES test_examples ${__DEPENDEES}
     TEST_COMMAND_OPTIONS ${__TEST_COMMAND_OPTIONS}
     DISABLE_EXECUTABLE_INSTALL_RULE
+    DISABLE_TESTS ${__DISABLE_TESTS}
   )

 endfunction()
@ -83,6 +88,11 @@ foreach(EXAMPLE
   15_ampere_sparse_tensorop_gemm
   16_ampere_tensorop_conv2dfprop
   17_fprop_per_channel_bias
+  18_ampere_fp64_tensorop_affine2_gemm
+  19_tensorop_canonical
+  20_simt_canonical
+  21_quaternion_gemm
+  22_quaternion_conv
 )

   add_subdirectory(${EXAMPLE})
@ -33,6 +33,26 @@
 namespace cutlass {
 namespace arch {

+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+
+/// Computes laneId within a warp
+CUTLASS_DEVICE
+int LaneId() {
+  int ret;
+  asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : );
+  return ret;
+}
+
+/// Computes SM number the thread is running on
+CUTLASS_DEVICE
+int SmId() {
+  int ret;
+  asm ("mov.u32 %0, %%smid;" : "=r"(ret) : );
+  return ret;
+}
+
+#endif
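// Illustrative usage (not part of the header): LaneId() lets warp-level code pick one
// thread per warp without decomposing threadIdx, e.g.
//
//   __global__ void example_kernel(int *out) {
//     if (cutlass::arch::LaneId() == 0) {
//       out[0] = cutlass::arch::SmId();   // one lane records the SM it runs on
//     }
//   }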
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 struct Sm50 {
   static int const kMinComputeCapability = 50;

@ -51,10 +51,20 @@ struct global_load;

 /////////////////////////////////////////////////////////////////////////////////////////////////

+#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
+     (__CUDACC_VER_MAJOR__ > 11)) && \
+    defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \
+    ! (defined(__clang__) && defined(__CUDA__))
+#define CUTLASS_ENABLE_L2_PREFETCH 1
+#else
+#define CUTLASS_ENABLE_L2_PREFETCH 0
+#endif
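
(Aside: the `L2::128B` qualifier on the `ld.global` variants below is the PTX L2 prefetch hint introduced with the CUDA 11.4 toolkit; each guarded load also fetches a 128-byte line into the L2 cache.)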

 /////////////////////////////////////////////////////////////////////////////////////////////////

 // The redundant mov PTX instruction is used to enforce the compiler to
 // initialize data to zero before ld.global
-template <typename AccessType
->
+template <typename AccessType>
 struct global_load<AccessType,
                    32
                   > {
@ -62,55 +72,61 @@ struct global_load<AccessType,
   global_load(AccessType &D, void const *ptr, bool pred_guard) {
     uint4 *data = reinterpret_cast<uint4 *>(&D);

-    asm volatile(
-        "{\n"
-        " .reg .pred p;\n"
-        " setp.ne.b32 p, %9, 0;\n"
-        " mov.b32 %0, %10;\n"
-        " mov.b32 %1, %11;\n"
-        " mov.b32 %2, %12;\n"
-        " mov.b32 %3, %13;\n"
-        " mov.b32 %4, %14;\n"
-        " mov.b32 %5, %15;\n"
-        " mov.b32 %6, %16;\n"
-        " mov.b32 %7, %17;\n"
-        " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
-        " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n"
-        "}\n"
-        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w),
-          "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w)
-        : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y),
-          "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y),
-          "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16));
+    asm volatile(
+        "{\n"
+        " .reg .pred p;\n"
+        " setp.ne.b32 p, %9, 0;\n"
+        " mov.b32 %0, %10;\n"
+        " mov.b32 %1, %11;\n"
+        " mov.b32 %2, %12;\n"
+        " mov.b32 %3, %13;\n"
+        " mov.b32 %4, %14;\n"
+        " mov.b32 %5, %15;\n"
+        " mov.b32 %6, %16;\n"
+        " mov.b32 %7, %17;\n"
+#if CUTLASS_ENABLE_L2_PREFETCH
+        " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%8];\n"
+        " @p ld.global.L2::128B.v4.u32 {%4, %5, %6, %7}, [%18];\n"
+#else
+        " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
+        " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n"
+#endif
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w),
+          "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w)
+        : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y),
+          "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y),
+          "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16));
   }
 };

-template <typename AccessType
->
+template <typename AccessType>
 struct global_load<AccessType,
                    16
                   > {
   CUTLASS_DEVICE
   global_load(AccessType &D, void const *ptr, bool pred_guard) {
     uint4 &data = reinterpret_cast<uint4 &>(D);

-    asm volatile(
-        "{\n"
-        " .reg .pred p;\n"
-        " setp.ne.b32 p, %5, 0;\n"
-        " mov.b32 %0, %6;\n"
-        " mov.b32 %1, %7;\n"
-        " mov.b32 %2, %8;\n"
-        " mov.b32 %3, %9;\n"
-        " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
-        "}\n"
-        : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
-        : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
+    asm volatile(
+        "{\n"
+        " .reg .pred p;\n"
+        " setp.ne.b32 p, %5, 0;\n"
+        " mov.b32 %0, %6;\n"
+        " mov.b32 %1, %7;\n"
+        " mov.b32 %2, %8;\n"
+        " mov.b32 %3, %9;\n"
+#if CUTLASS_ENABLE_L2_PREFETCH
+        " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+#else
+        " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+#endif
+        "}\n"
+        : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
+        : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
   }
 };
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
template <typename AccessType>
|
||||
struct global_load<AccessType,
|
||||
8
|
||||
> {
|
||||
@ -118,21 +134,24 @@ struct global_load<AccessType,
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
uint2 &data = reinterpret_cast<uint2 &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %3, 0;\n"
|
||||
" mov.b32 %0, %4;\n"
|
||||
" mov.b32 %1, %5;\n"
|
||||
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
|
||||
"}\n"
|
||||
: "=r"(data.x), "=r"(data.y)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %3, 0;\n"
|
||||
" mov.b32 %0, %4;\n"
|
||||
" mov.b32 %1, %5;\n"
|
||||
#if CUTLASS_ENABLE_L2_PREFETCH
|
||||
" @p ld.global.L2::128B.v2.u32 {%0, %1}, [%2];\n"
|
||||
#else
|
||||
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
|
||||
#endif
|
||||
"}\n"
|
||||
: "=r"(data.x), "=r"(data.y)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
template <typename AccessType>
|
||||
struct global_load<AccessType,
|
||||
4
|
||||
> {
|
||||
@ -140,20 +159,23 @@ struct global_load<AccessType,
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
unsigned &data = reinterpret_cast<unsigned &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" mov.b32 %0, %3;\n"
|
||||
" @p ld.global.u32 %0, [%1];\n"
|
||||
"}\n"
|
||||
: "=r"(data)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" mov.b32 %0, %3;\n"
|
||||
#if CUTLASS_ENABLE_L2_PREFETCH
|
||||
" @p ld.global.L2::128B.u32 %0, [%1];\n"
|
||||
#else
|
||||
" @p ld.global.u32 %0, [%1];\n"
|
||||
#endif
|
||||
"}\n"
|
||||
: "=r"(data)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
template <typename AccessType>
|
||||
struct global_load<AccessType,
|
||||
2
|
||||
> {
|
||||
@ -161,20 +183,23 @@ struct global_load<AccessType,
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
uint16_t &data = reinterpret_cast<uint16_t &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" mov.b16 %0, %3;\n"
|
||||
" @p ld.global.u16 %0, [%1];\n"
|
||||
"}\n"
|
||||
: "=h"(data)
|
||||
: "l"(ptr), "r"((int)pred_guard), "h"(data));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" mov.b16 %0, %3;\n"
|
||||
#if CUTLASS_ENABLE_L2_PREFETCH
|
||||
" @p ld.global.L2::128B.u16 %0, [%1];\n"
|
||||
#else
|
||||
" @p ld.global.u16 %0, [%1];\n"
|
||||
#endif
|
||||
"}\n"
|
||||
: "=h"(data)
|
||||
: "l"(ptr), "r"((int)pred_guard), "h"(data));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
template <typename AccessType>
|
||||
struct global_load<AccessType,
|
||||
1
|
||||
> {
|
||||
|
||||
@@ -30,6 +30,7 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/cache_operation.h"

@@ -90,7 +91,11 @@ struct cp_async<SizeInBytes, CacheOperation::Always> {
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
" @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
#else
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
#endif
"}\n" ::"r"((int)pred_guard),
"r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));

@@ -123,7 +128,11 @@ struct cp_async_zfill<SizeInBytes, CacheOperation::Always> {
int src_in_bytes = (pred_guard ? SizeInBytes : 0);

asm volatile(
#if CUTLASS_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
#else
"cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
#endif
"l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));

#else
@@ -163,7 +172,11 @@ struct cp_async<SizeInBytes, CacheOperation::Global> {
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
" @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
#else
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
#endif
"}\n" ::"r"((int)pred_guard),
"r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));

@@ -195,7 +208,11 @@ struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
int src_in_bytes = (pred_guard ? SizeInBytes : 0);

asm volatile(
#if CUTLASS_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
#else
"cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
#endif
"l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));

#else

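// Editor-added usage sketch (not part of the diff): cp_async stages data from
// global to shared memory asynchronously; with CUTLASS_ENABLE_L2_PREFETCH set,
// the same call emits the .L2::128B variants shown above. The kernel below is
// hypothetical; cp_async / cp_async_fence / cp_async_wait follow the names in
// cutlass/arch/memory_sm80.h.
__global__ void stage_tile(float4 const *gmem, bool valid) {
  __shared__ float4 smem[128];
  cutlass::arch::cp_async<sizeof(float4), cutlass::arch::CacheOperation::Always>(
      smem + threadIdx.x, gmem + threadIdx.x, valid);
  cutlass::arch::cp_async_fence();    // commit the copies issued so far as one group
  cutlass::arch::cp_async_wait<0>();  // block until every committed group has landed
  __syncthreads();
}
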
@@ -30,6 +30,7 @@

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/arch.h"
@@ -130,11 +131,12 @@ template <
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Inner product operator
typename Operator
typename Operator_
>
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator> {
struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> {

using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = Operator_;

CUTLASS_HOST_DEVICE
void operator()(
@@ -144,7 +146,9 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, El
Array<ElementC, 1> const &c
) {

d[0] = a[0] * b[0] + c[0];
multiply_add<ElementA, ElementB, ElementC> op;

d[0] = op(a[0], b[0], c[0]);
}
};

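// Editor-added reference sketch (an assumption, not the exact library code):
// multiply_add is the generic fused multiply-add functor from
// cutlass/functional.h. Routing the scalar Mma through it, as the hunk above
// does, lets element types such as Quaternion<float> specialize the operation.
// A rough equivalent of the generic case:
template <typename A, typename B, typename C>
struct multiply_add_sketch {
  CUTLASS_HOST_DEVICE
  C operator()(A const &a, B const &b, C const &c) const {
    return C(a) * C(b) + c;  // real specializations may fuse this into one FMA
  }
};
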
@@ -30,6 +30,8 @@

#include "cutlass/arch/mma.h"
#include "cutlass/complex.h"
#include "cutlass/quaternion.h"
#include "cutlass/functional.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
@@ -379,5 +381,35 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float,

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation for Quaternions
template <
/// Layout of A matrix
typename LayoutA,
/// Layout of B matrix
typename LayoutB,
/// Layout of C matrix
typename LayoutC
>
struct Mma<gemm::GemmShape<1, 1, 1>, 1, Quaternion<float>, LayoutA, Quaternion<float>, LayoutB, Quaternion<float>, LayoutC, OpMultiplyAdd> {

using Shape = gemm::GemmShape<1, 1, 1>;
using Operator = OpMultiplyAdd;
using Element = Quaternion<float>;

CUTLASS_HOST_DEVICE
void operator()(
Array<Element, 1> &d,
Array<Element, 1> const &a,
Array<Element, 1> const &b,
Array<Element, 1> const &c
) {
multiply_add<Element, Element, Element> op;
d[0] = op(a[0], b[0], c[0]);
}

};

}
}

/////////////////////////////////////////////////////////////////////////////////////////////////

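// Editor-added note: the quaternion multiply-add above reduces to the Hamilton
// product followed by an addition. With q = w + x*i + y*j + z*k,
// d = a * b + c expands component-wise to:
//
//   d.w = a.w*b.w - a.x*b.x - a.y*b.y - a.z*b.z + c.w
//   d.x = a.w*b.x + a.x*b.w + a.y*b.z - a.z*b.y + c.x
//   d.y = a.w*b.y + a.y*b.w + a.z*b.x - a.x*b.z + c.y
//   d.z = a.w*b.z + a.z*b.w + a.x*b.y - a.y*b.x + c.z
//
// i.e. 16 real multiply-adds per quaternion multiply-add, which the
// multiply_add<Quaternion<float>, ...> specialization evaluates on CUDA cores.
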
@@ -29,7 +29,7 @@
#pragma once

// CUTLASS WMMA does not support clang at present.
#if !defined(__clang__)
#if !(defined(__clang__) && defined(__CUDA__))

#if (__CUDACC_VER_MAJOR__ >= 9)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))
@@ -52,7 +52,7 @@
#endif
#endif

#endif //!defined(__clang__)
#endif //!(defined(__clang__) && defined(__CUDA__))

#if defined(CUTLASS_ARCH_WMMA_ENABLED)


@@ -49,7 +49,7 @@ class Array;
template <typename T, int N, bool RegisterSized>
struct sizeof_bits<Array<T, N, RegisterSized> > {
static int const value =
sizeof(typename Array<T, N, RegisterSized>::Storage) * 8 * Array<T, N, RegisterSized>::kStorageElements;
int(sizeof(typename Array<T, N, RegisterSized>::Storage)) * 8 * int(Array<T, N, RegisterSized>::kStorageElements);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -62,7 +62,7 @@ public:
using Element = T;

/// Number of logical elements per stored object
static int const kElementsPerStoredItem = (sizeof(Storage) * 8) / sizeof_bits<T>::value;
static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits<T>::value;

/// Number of storage elements
static size_t const kStorageElements = N / kElementsPerStoredItem;

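// Editor-added worked example (reasoning is the editor's, hedged): because
// kStorageElements is declared size_t just above and sizeof() yields size_t,
// the old expression promoted the whole product to unsigned 64-bit, which then
// narrowed when initializing the int 'value'. The int(...) casts keep the
// arithmetic in int. For something like Array<half_t, 8>, assuming a 16-bit
// Storage type:
//
//   sizeof(Storage) = 2 bytes, kStorageElements = 8
//   value = int(2) * 8 * int(8) = 128 bits
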
@@ -33,6 +33,7 @@
#include <cmath>
#include <limits>
#include <cstdint>
#include <cstring>
#endif

#include "cutlass/cutlass.h"
@@ -76,7 +77,13 @@ struct alignas(2) bfloat16_t {
asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));

#else
uint32_t bits = reinterpret_cast<uint32_t &>(x);
uint32_t bits;

#if defined(__CUDA_ARCH__)
bits = reinterpret_cast<uint32_t &>(x);
#else
std::memcpy(&bits, &x, sizeof(bits));
#endif

if ((bits & 0x7f800000) != 0x7f800000) {

@@ -106,14 +113,28 @@ struct alignas(2) bfloat16_t {
CUTLASS_HOST_DEVICE
explicit bfloat16_t(int x) {
float flt = static_cast<float>(x);
storage = uint16_t(reinterpret_cast<uint32_t const &>(flt) >> 16);
uint32_t bits;

#if defined(__CUDA_ARCH__)
bits = reinterpret_cast<uint32_t &>(flt);
#else
std::memcpy(&bits, &flt, sizeof(bits));
#endif

storage = uint16_t(bits >> 16);
}

/// Converts to float
CUTLASS_HOST_DEVICE
operator float() const {
unsigned bits = (unsigned(storage) << 16);
#if defined(__CUDA_ARCH__)
return reinterpret_cast<float const &>(bits);
#else
float flt;
std::memcpy(&flt, &bits, sizeof(flt));
return flt;
#endif
}

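// Editor-added sketch of the pattern adopted above: reinterpret_cast between
// float and uint32_t violates strict aliasing on the host, so the host path
// goes through std::memcpy, which compilers optimize to a plain register move.
// The helper names are hypothetical.
#include <cstdint>
#include <cstring>

inline std::uint32_t float_bits(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // well-defined bit copy
  return bits;
}

inline std::uint16_t float_to_bf16_bits(float f) {
  return std::uint16_t(float_bits(f) >> 16);  // truncate, as in bfloat16_t(int)
}
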
/// Converts to float
@@ -237,11 +258,22 @@ cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) {
CUTLASS_HOST_DEVICE
bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) {

uint16_t a_mag = (reinterpret_cast<uint16_t const &>(a) & 0x7fff);
uint16_t b_sign = (reinterpret_cast<uint16_t const &>(b) & 0x8000);
uint16_t a_bits;
uint16_t b_bits;

#if defined(__CUDA_ARCH__)
a_bits = reinterpret_cast<uint16_t const &>(a);
b_bits = reinterpret_cast<uint16_t const &>(b);
#else
std::memcpy(&a_bits, &a, sizeof(a_bits));
std::memcpy(&b_bits, &b, sizeof(b_bits));
#endif

uint16_t a_mag = (a_bits & 0x7fff);
uint16_t b_sign = (b_bits & 0x8000);
uint16_t result = (a_mag | b_sign);

return reinterpret_cast<bfloat16_t const &>(result);
return bfloat16_t::bitcast(result);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -38,6 +38,8 @@
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"

#include "cutlass/fast_math.h"

#if !defined(__CUDACC_RTC__)
#include <iosfwd>
#endif
@@ -442,16 +444,16 @@ CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
/// Computes the complex exponential of z.
template <typename T>
CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
return complex<T>(real(z) * cos(imag(z)), real(z) * sin(imag(z)));
return complex<T>(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z)));
}

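// Editor-added note: the corrected body above follows Euler's formula,
//   exp(x + iy) = e^x * (cos y + i sin y),
// where the old code omitted the e^x factor. fast_exp/fast_cos/fast_sin come
// from the cutlass/fast_math.h header added in this hunk.
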
/// Computes the complex exponential of z.
/// Computes the log of z
template <typename T>
CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
return complex<T>(log(abs(z)), arg(z));
}

/// Computes the complex exponential of z.
/// Computes the log base 10 of z
template <typename T>
CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
return log(z) / T(log(T(10)));
@@ -484,6 +486,9 @@ template <typename T>
struct RealType< complex<T> > {
using Type = T;

/// Number of elements
static int const kExtent = 2;

CUTLASS_HOST_DEVICE
static complex<T> from_real(double x) {
return complex<T>(static_cast<T>(x));

@@ -284,6 +284,27 @@ public:

return cutlass::MatrixCoord ({dilation_h, dilation_w});
}

/////////////////////////////////////////////////////////////////
// Methods used for strided dgrad implementation
/////////////////////////////////////////////////////////////////
/// Number of filter r positions to accumulate in gemm-k dim
CUTLASS_HOST_DEVICE
int num_gemm_k_filter_r(int r) const {
return ((R - r + stride_h - 1) / stride_h);
}

/// Number of filter s positions to accumulate in gemm-k dim
CUTLASS_HOST_DEVICE
int num_gemm_k_filter_s(int s) const {
return ((S - s + stride_w - 1) / stride_w);
}

/// Number of filter positions to accumulate in gemm-k dim
CUTLASS_HOST_DEVICE
int num_gemm_k_filter_positions(int r, int s) const {
return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s);
}
};

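// Editor-added worked example: num_gemm_k_filter_r is a ceiling division that
// counts how many filter rows r, r + stride_h, r + 2*stride_h, ... fall inside R.
// For R = 3, stride_h = 2:
//
//   num_gemm_k_filter_r(0) = (3 - 0 + 1) / 2 = 2   (rows 0 and 2)
//   num_gemm_k_filter_r(1) = (3 - 1 + 1) / 2 = 1   (row 1 only)
//
// num_gemm_k_filter_positions is simply the product of the r and s counts.
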
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -444,6 +465,27 @@ int64_t implicit_gemm_tensor_c_size(

////////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////////
// Strided dgrad helper functions //
////////////////////////////////////////////////////////////////////////////////////////////////////
// Returns the number of CTA tiles in M needed to cover the valid MMAs per starting filter position
CUTLASS_HOST_DEVICE
int strided_dgrad_tile_m_per_filter(
Conv2dProblemSize const &problem_size,
int tile_size_m) {

// Compute the NHW rows in the Dx output that need MMAs per starting filter position
int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h;
int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w;
int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter;

// Number of CTA tiles in M needed to cover the valid MMAs per starting filter position
int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m;

return tile_m_per_filter;
}

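// Editor-added worked example (hypothetical sizes): for N = 1, H = W = 56,
// stride 2x2, and tile_size_m = 128:
//
//   rows_h_per_filter   = (56 + 2 - 1) / 2 = 28
//   rows_w_per_filter   = (56 + 2 - 1) / 2 = 28
//   rows_nhw_per_filter = 1 * 28 * 28 = 784
//   tile_m_per_filter   = (784 + 128 - 1) / 128 = 7
//
// so each starting filter position is covered by 7 CTA tiles in M.
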

} // namespace conv
} // namespace cutlass

@@ -115,4 +115,3 @@ enum class SplitKMode {
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -71,6 +71,7 @@ public:

static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator;
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm;
static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport;

static int const kWarpCount =
(ThreadblockShape::kM / WarpShape::kM) *
@@ -104,12 +105,37 @@ public:
return status;
}

// Check for problem sizes unsupported by the strided dgrad implementation
if (kConvolutionalOperator == conv::Operator::kDgrad &&
kStrideSupport == conv::StrideSupport::kStrided) {

// Unity stride (1x1) is supported by strided dgrad but disabled for performance
// reasons. For unity stride, use the unity-stride-optimized dgrad specialization.
// Note that unit tests exercise strided dgrad with unity stride to make sure the
// strided dgrad implementation is functionally sound.
// The strided dgrad implementation also supports mixed strides, i.e., (1x2) and (2x1).
if(args.problem_size.stride_h == 1 && args.problem_size.stride_w == 1) {
return Status::kErrorNotSupported;
}

// split-k (serial or parallel) is not supported for strided dgrad
if(args.problem_size.split_k_slices > 1) {
return Status::kErrorNotSupported;
}

// dilation > {1x1} is not supported for strided dgrad
if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) {
return Status::kErrorNotSupported;
}
}

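// Editor-added usage sketch (hedged, hypothetical helper): callers are
// expected to consult can_implement() before launching, which is where the
// strided dgrad checks above reject unity stride, split-K, and dilation > 1.
template <typename ImplicitGemm>
cutlass::Status run_if_supported(typename ImplicitGemm::Arguments const &args) {
  ImplicitGemm conv_op;
  cutlass::Status status = conv_op.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;  // e.g. kErrorNotSupported for 1x1 stride with strided dgrad
  }
  status = conv_op.initialize(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return conv_op();  // launch using initialized state
}
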
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;

dim3 grid = threadblock_swizzle.get_grid_shape(
threadblock_swizzle.get_tiled_shape(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
kConvolutionalOperator,
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.problem_size.split_k_slices));

@@ -131,7 +157,8 @@ public:
ThreadblockSwizzle threadblock_swizzle;

cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
kConvolutionalOperator,
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.problem_size.split_k_slices);

@@ -220,6 +247,7 @@ public:
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {

ThreadblockSwizzle threadblock_swizzle;

dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);

@@ -33,6 +33,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
@@ -41,6 +42,9 @@
#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
#include "cutlass/conv/threadblock/implicit_gemm_multistage.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@@ -62,7 +66,7 @@ struct DefaultConvEpilogue {
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
Shape,
WarpMmaTensorOp,
1,
PartitionsK,
OutputOp,
OutputOp::kCount
>::Epilogue;
@@ -85,7 +89,49 @@ struct DefaultConvEpilogue<
using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
Shape,
WarpMmaTensorOp,
1,
PartitionsK,
OutputOp,
OutputOp::kCount
>::Epilogue;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

// Defaults for strided Dgrad
template <
typename ArchTag,
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename OutputOp
>
struct DefaultConvEpilogueStridedDgrad {
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
OutputOp::kCount
>::Epilogue;
};

template <
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename OutputOp
>
struct DefaultConvEpilogueStridedDgrad<
arch::Sm70,
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp
> {

using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
OutputOp::kCount
>::Epilogue;

@@ -35,7 +35,7 @@
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
@@ -83,7 +83,6 @@ template <
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
@@ -101,7 +100,7 @@ struct DefaultConv2dDgrad <
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
@@ -117,7 +116,7 @@ struct DefaultConv2dDgrad <
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;

// Define iterators over tiles from the A operand
@@ -138,7 +137,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kStrided
>;

using SmemIteratorB = typename MmaCore::SmemIteratorB;
@@ -160,17 +160,19 @@ struct DefaultConv2dDgrad <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;

// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
@@ -188,7 +190,6 @@ template <
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
@@ -205,7 +206,7 @@ struct DefaultConv2dDgrad <
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
@@ -221,13 +222,13 @@ struct DefaultConv2dDgrad <
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
@@ -241,11 +242,12 @@ struct DefaultConv2dDgrad <
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kStrided
>
>;

@@ -267,17 +269,19 @@ struct DefaultConv2dDgrad <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;

// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
@@ -297,7 +301,6 @@ template <
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
@@ -315,7 +318,7 @@ struct DefaultConv2dDgrad <
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
@@ -331,7 +334,7 @@ struct DefaultConv2dDgrad <
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;

// Define iterators over tiles from the A operand
@@ -352,7 +355,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kUnity
>;

using SmemIteratorB = typename MmaCore::SmemIteratorB;
@@ -374,11 +378,13 @@ struct DefaultConv2dDgrad <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
@@ -402,7 +408,6 @@ template <
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
@@ -419,7 +424,7 @@ struct DefaultConv2dDgrad <
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
@@ -435,7 +440,7 @@ struct DefaultConv2dDgrad <
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;

// Define iterators over tiles from the A operand
@@ -459,7 +464,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kUnity
>
>;

@@ -481,12 +487,14 @@ struct DefaultConv2dDgrad <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;

@@ -511,7 +519,6 @@ template <
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
@@ -529,7 +536,7 @@ struct DefaultConv2dDgrad <
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
@@ -545,7 +552,7 @@ struct DefaultConv2dDgrad <
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;

// Define iterators over tiles from the A operand
@@ -588,11 +595,13 @@ struct DefaultConv2dDgrad <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
@@ -616,7 +625,6 @@ template <
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
@@ -633,7 +641,7 @@ struct DefaultConv2dDgrad <
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
@@ -649,7 +657,7 @@ struct DefaultConv2dDgrad <
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;

// Define iterators over tiles from the A operand
@@ -695,12 +703,14 @@ struct DefaultConv2dDgrad <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;

@@ -734,8 +744,7 @@ template <
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
>
typename MathOperatorTag>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
@@ -754,7 +763,7 @@ struct DefaultConv2dDgrad <
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport::kStrided
conv::StrideSupport::kUnity
> {

// Define the core components from GEMM
@@ -770,7 +779,7 @@ struct DefaultConv2dDgrad <
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kStrided
conv::StrideSupport::kUnity
>;

using SmemIteratorA = typename MmaCore::SmemIteratorA;
@@ -781,7 +790,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
conv::StrideSupport::kUnity
>;

using SmemIteratorB = typename MmaCore::SmemIteratorB;
@@ -823,6 +833,110 @@ struct DefaultConv2dDgrad <

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
conv::StrideSupport::kStrided
> {

// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
Stages, MathOperatorTag>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
conv::StrideSupport::kStrided
>;

using SmemIteratorA = typename MmaCore::SmemIteratorA;

// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
conv::StrideSupport::kStrided
>;

using SmemIteratorB = typename MmaCore::SmemIteratorB;

// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;

// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Always,
MmaPolicy,
Stages
>;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
ThreadblockShape,
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;

// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kDgrad
>;

};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm,
/// multi-stage pipeline, and FFMA-based mainloop for SM80

@@ -888,7 +1002,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kUnity
>;

using SmemIteratorB = typename MmaCore::SmemIteratorB;
@@ -928,6 +1043,8 @@ struct DefaultConv2dDgrad <

};

/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
@@ -966,7 +1083,7 @@ struct DefaultConv2dDgrad <
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
StrideSupport::kStrided
conv::StrideSupport::kUnity
> {

// Define the core components from GEMM
@@ -983,7 +1100,7 @@ struct DefaultConv2dDgrad <
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
StrideSupport::kStrided
conv::StrideSupport::kUnity
>
>;

@@ -996,7 +1113,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
conv::StrideSupport::kUnity
>
>;

@@ -1034,6 +1152,112 @@ struct DefaultConv2dDgrad <
conv::Operator::kDgrad
>;

};
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag
>
struct DefaultConv2dDgrad <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
IteratorAlgorithm::kAnalytic,
conv::StrideSupport::kStrided
> {

// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
2, MathOperatorTag>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA,
conv::StrideSupport::kStrided
>
>;

using SmemIteratorA = typename MmaCore::SmemIteratorA;

// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::TileIteratorStridedDgrad<
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB,
conv::StrideSupport::kStrided
>
>;

using SmemIteratorB = typename MmaCore::SmemIteratorB;

// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;

// Define the Mma
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
ThreadblockShape,
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;

// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kDgrad
>;

};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1104,7 +1328,8 @@ struct DefaultConv2dDgrad <
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
ThreadMapB,
StrideSupport::kUnity
>
>;

@@ -1144,8 +1369,6 @@ struct DefaultConv2dDgrad <

};
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

@@ -157,11 +157,13 @@ struct DefaultConv2dFprop <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
@@ -271,11 +273,13 @@ struct DefaultConv2dFprop <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount,
InterleavedK
@@ -378,12 +382,14 @@ struct DefaultConv2dFprop <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;

@@ -494,11 +500,13 @@ struct DefaultConv2dFprop <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount,
InterleavedK
@@ -602,11 +610,13 @@ struct DefaultConv2dFprop <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
@@ -708,11 +718,13 @@ struct DefaultConv2dFprop <
Stages
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount,
InterleavedK
@@ -817,12 +829,14 @@ struct DefaultConv2dFprop <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;

@@ -923,11 +937,13 @@ struct DefaultConv2dFprop <
MmaPolicy
>;

static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
ThreadblockShape,
WarpMmaTensorOp,
1,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount,
InterleavedK

@@ -0,0 +1,117 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

/*! \file
\brief
Defines a Conv2d Fprop kernel with a fused broadcast epilogue based on an existing DefaultConv2dFprop kernel.

*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h"

#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
>
struct DefaultConv2dFpropWithBroadcast {

using ImplicitGemmBase = typename DefaultConv2dFprop<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm,
StrideSupport
>::Kernel;

// Replace epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp<
typename ImplicitGemmBase::Epilogue::Shape,
typename ImplicitGemmBase::Epilogue::WarpMmaOperator,
ImplicitGemmBase::Epilogue::kPartitionsK,
ElementC,
typename EpilogueOutputOp::ElementT,
ElementC,
EpilogueOutputOp,
ImplicitGemmBase::Epilogue::kElementsPerAccess
>::Epilogue;

// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue<
typename ImplicitGemmBase::Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop
>;
};

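// Editor-added note and sketch: the struct above composes a stock
// DefaultConv2dFprop kernel and swaps only its epilogue, so the mainloop (Mma)
// is reused unchanged. A sketch of an instantiation follows; the shape aliases
// and EpilogueOp are placeholders the user must define, and EpilogueOp must be
// a broadcast-capable functor (for example from the
// LinearCombinationBiasElementwise family), so this is illustrative rather
// than prescriptive:
//
//   using Kernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
//     cutlass::half_t, cutlass::layout::TensorNHWC,   // A
//     cutlass::half_t, cutlass::layout::TensorNHWC,   // B
//     cutlass::half_t, cutlass::layout::TensorNHWC,   // C
//     float,                                          // accumulator
//     cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
//     ThreadblockShape, WarpShape, InstructionShape,
//     EpilogueOp, ThreadblockSwizzle, 2,
//     cutlass::arch::OpMultiplyAdd
//   >::Kernel;
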
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,117 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||

/*! \file
    \brief
      Defines a GEMM with Reduction based on an existing UniversalGemm kernel.

*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h"

#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h"
#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename OperatorClass,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename EpilogueReductionOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided
>
struct DefaultConv2dFpropWithReduction {

  using ImplicitGemmBase = typename DefaultConv2dFprop<
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    MathOperatorTag,
    IteratorAlgorithm,
    StrideSupport
  >::Kernel;

  // Replace epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
    typename ImplicitGemmBase::Epilogue::Shape,
    typename ImplicitGemmBase::Epilogue::WarpMmaOperator,
    ImplicitGemmBase::Epilogue::kPartitionsK,
    ElementC,
    EpilogueOutputOp,
    EpilogueReductionOp,
    ImplicitGemmBase::Epilogue::kElementsPerAccess
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue<
    typename ImplicitGemmBase::Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
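
The struct above follows the composition pattern used by these default kernels: instantiate the base DefaultConv2dFprop, then rebuild the kernel type around a replacement epilogue while reusing the base mainloop. A minimal toy sketch of that pattern (the stand-in types below are hypothetical, for illustration only, not CUTLASS types):

#include <type_traits>

// Toy stand-ins; the real types are the CUTLASS Mma and Epilogue components.
template <typename Mma, typename Epilogue> struct ToyKernel {
  using MmaType = Mma;
  using EpilogueType = Epilogue;
};
struct ToyMma {};
struct DefaultEpilogue {};
struct EpilogueWithReduction {};

// Base kernel, analogous to DefaultConv2dFprop<...>::Kernel.
using Base = ToyKernel<ToyMma, DefaultEpilogue>;

// Rebuilt kernel: keep the base Mma, swap the epilogue, exactly as
// DefaultConv2dFpropWithReduction does above.
using Fused = ToyKernel<Base::MmaType, EpilogueWithReduction>;

static_assert(std::is_same<Base::MmaType, Fused::MmaType>::value, "Mma reused");
static_assert(!std::is_same<Base, Fused>::value, "epilogue replaced");

int main() { return 0; }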
@ -160,11 +160,13 @@ struct DefaultConv2dWgrad <
    Stages
  >;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    1,
    kPartitionsK,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;
@ -266,12 +268,14 @@ struct DefaultConv2dWgrad <
    MmaPolicy
  >;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag,
    ThreadblockShape,
    WarpMmaTensorOp,
    1,
    kPartitionsK,
    EpilogueOutputOp
  >::Epilogue;

@ -371,11 +375,13 @@ struct DefaultConv2dWgrad <
    Stages
  >;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    1,
    kPartitionsK,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;
@ -477,12 +483,14 @@ struct DefaultConv2dWgrad <
    MmaPolicy
  >;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag,
    ThreadblockShape,
    WarpMmaTensorOp,
    1,
    kPartitionsK,
    EpilogueOutputOp
  >::Epilogue;

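In each hunk above, kPartitionsK counts how many warps are stacked along the threadblock's K dimension, and the epilogue must reduce across those partial accumulators instead of assuming a single partition. The arithmetic in isolation (tile sizes here are assumed example values):

#include <cstdio>

int main() {
  // Assumed tile shapes: a threadblock with kK = 64 built from warps with kK = 32.
  int threadblock_k = 64;
  int warp_k = 32;

  // Mirrors: static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
  int kPartitionsK = threadblock_k / warp_k;

  // Two warps each hold a partial sum over half of K; the epilogue adds them.
  std::printf("kPartitionsK = %d\n", kPartitionsK);
  return 0;
}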
@ -92,7 +92,8 @@ struct ImplicitGemmConvolution {

  static int const kStages = Mma::kStages;
  static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;

  static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;
@ -188,6 +189,8 @@ struct ImplicitGemmConvolution {
    ConvProblemSize problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    gemm::GemmCoord implicit_gemm_problem_size;
    int swizzle_log_tile;

    int gemm_k_iterations;
    typename Mma::IteratorA::Params iterator_A;
    typename Mma::IteratorA::Element const *ptr_A;
@ -206,7 +209,7 @@ struct ImplicitGemmConvolution {
    //

    CUTLASS_HOST_DEVICE
    Params(): gemm_k_iterations(0) { }
    Params(): swizzle_log_tile(0), gemm_k_iterations(0) { }

    ///
    CUTLASS_HOST_DEVICE
@ -236,6 +239,8 @@ struct ImplicitGemmConvolution {
        implicit_gemm_problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices);

      swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
    }
  };

@ -260,7 +265,7 @@ struct ImplicitGemmConvolution {
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_idx =
      threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
      threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
@ -327,7 +332,7 @@ struct ImplicitGemmConvolution {

    // Compute logical position within grid
    threadblock_tile_idx =
      threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
      threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // If performing a reduction via split-K, fetch the initial synchronization
    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

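The new swizzle_log_tile member caches the swizzle's tile-grouping factor on the host so that get_tile_offset() no longer has to derive it from grid_tiled_shape inside the kernel. A simplified host-side illustration of the idea (the grouping policy and bit layout below are assumptions for the sketch, not the CUTLASS implementation):

#include <cstdio>

// Group CTAs into 2^log_tile-wide columns to improve L2 reuse.
int get_log_tile_sketch(int grid_n) {
  if (grid_n >= 8) return 3;
  if (grid_n >= 4) return 2;
  if (grid_n >= 2) return 1;
  return 0;
}

// Remap a raw (block_x, block_y) launch index to a swizzled tile index.
void get_tile_offset_sketch(int block_x, int block_y, int log_tile,
                            int &tile_m, int &tile_n) {
  tile_m = block_x >> log_tile;
  tile_n = (block_y << log_tile) | (block_x & ((1 << log_tile) - 1));
}

int main() {
  int log_tile = get_log_tile_sketch(/*grid_n=*/6);  // -> 2
  int m, n;
  get_tile_offset_sketch(/*block_x=*/5, /*block_y=*/1, log_tile, m, n);
  std::printf("log_tile=%d, tile=(%d, %d)\n", log_tile, m, n);
  return 0;
}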
@ -0,0 +1,461 @@
/*! \file
    \brief Template for a pipelined Implicit GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/semaphore.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                                 ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,                            ///! Epilogue
  typename ThreadblockSwizzle_,                  ///! Threadblock swizzling function
  conv::Operator ConvOperator,                   ///! Convolutional operator (Fprop, Dgrad, Wgrad)
  typename ConvProblemSize_ = Conv2dProblemSize  ///! Convolutional operator on 2D or 3D problem
>
struct ImplicitGemmConvolutionStridedDgrad {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static Operator const kConvolutionalOperator = ConvOperator;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename EpilogueOutputOp::ElementOutput;

  /// Set output tensor C layout
  using LayoutC = LayoutA;

  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
  using ElementCompute = typename EpilogueOutputOp::ElementCompute;

  using WarpMmaOperator = typename Mma::Policy::Operator;

  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename ArchMmaOperator::Operator;

  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;

  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename WarpMmaOperator::Shape;
  using InstructionShape = typename ArchMmaOperator::Shape;

  static int const kStages = Mma::kStages;
  static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
  static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  using TensorRefA = typename Mma::IteratorA::TensorRef;
  using TensorRefB = typename Mma::IteratorB::TensorRef;
  using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;

  /// Check iterator A and B convolution dimension are the same and
  // set device::ImplicitGemmConvolution::kConvDim
  static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
    "Convolution on different dimensions is not supported");
  static int const kConvDim = Mma::IteratorA::kConvDim;

  /// Conv dimension and problem size structure (Conv2d or Conv3d)
  using ConvProblemSize = ConvProblemSize_;

  /// Wgrad C stride idx for implicit gemm algorithm
  // Conv2d row-major matrix C (KxRSC)
  // Conv3d row-major matrix C (KxTRSC)
  static int const kWgradCStrideIdx =
    cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;

  /// This chooses the appropriate stride element of the C tensor.
  static int const kTensorCStrideIdx =
    (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);

  // Strided dgrad uses a specialized threadblock swizzle for functionality and performance
  static_assert((std::is_same<ThreadblockSwizzle,
                              threadblock::StridedDgradHorizontalThreadblockSwizzle>::value) ||
                (std::is_same<ThreadblockSwizzle,
                              threadblock::StridedDgradIdentityThreadblockSwizzle<1>>::value) ||
                (std::is_same<ThreadblockSwizzle,
                              threadblock::StridedDgradIdentityThreadblockSwizzle<4>>::value) ||
                (std::is_same<ThreadblockSwizzle,
                              threadblock::StridedDgradIdentityThreadblockSwizzle<8>>::value),
    "Needs ThreadblockSwizzle type specialized for strided dgrad");

  //
  //
  //
  using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
    LayoutC,
    typename Epilogue::OutputTileIterator::Layout,
    TensorRefC,
    ConvOperator,
    ConvProblemSize
  >;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    ConvProblemSize problem_size;
    TensorRefA ref_A;
    TensorRefB ref_B;
    TensorRefC ref_C;
    TensorRefC ref_D;
    typename EpilogueOutputOp::Params output_op;
    SplitKMode split_k_mode;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ConvProblemSize const & problem_size
    ):
      problem_size(problem_size) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ConvProblemSize const & problem_size,
      TensorRefA const & ref_A,
      TensorRefB const & ref_B,
      TensorRefC const & ref_C,
      TensorRefC const & ref_D,
      typename EpilogueOutputOp::Params const & output_op,
      SplitKMode const & split_k_mode = SplitKMode::kSerial
    ):
      problem_size(problem_size),
      ref_A(ref_A),
      ref_B(ref_B),
      ref_C(ref_C),
      ref_D(ref_D),
      output_op(output_op),
      split_k_mode(split_k_mode)
    {

    }

  };

  /// Parameters structure
  struct Params {
    ConvProblemSize problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    FastDivmod filter_s_divmod;
    int gemm_k_iterations;
    typename Mma::IteratorA::Params iterator_A;
    typename Mma::IteratorA::Element const *ptr_A;
    typename Mma::IteratorB::Params iterator_B;
    typename Mma::IteratorB::Element const *ptr_B;
    typename Epilogue::OutputTileIterator::Params iterator_C;
    typename Epilogue::OutputTileIterator::Element *ptr_C;
    typename Epilogue::OutputTileIterator::Params iterator_D;
    typename Epilogue::OutputTileIterator::Element *ptr_D;
    typename EpilogueOutputOp::Params output_op;
    int *semaphore;
    SplitKMode split_k_mode;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params(): gemm_k_iterations(0) { }

    ///
    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      int *semaphore = nullptr
    ):
      problem_size(args.problem_size),
      filter_s_divmod(args.problem_size.stride_w),
      iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
      ptr_A(args.ref_A.data()),
      iterator_B(args.problem_size, args.ref_B.layout()),
      ptr_B(args.ref_B.data()),
      iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM),
      ptr_C(args.ref_C.data()),
      iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM),
      ptr_D(args.ref_D.data()),
      output_op(args.output_op),
      semaphore(semaphore),
      split_k_mode(args.split_k_mode)
    {
      gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size);

      ThreadblockSwizzle threadblock_swizzle;

      grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        kConvolutionalOperator,
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices);
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  ImplicitGemmConvolutionStridedDgrad() { }

  /// Executes one ImplicitGEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_idx =
      threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {

      return;
    }

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Compute starting filter position for strided dgrad
    int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size,
                                                            ThreadblockShape::kM);
    int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter);

    // The subsequent fast_divmod() operations are equivalent to the following logical computation:
    //
    //   int start_r = filter_tile_m / (params.problem_size.stride_w);
    //   int start_s = filter_tile_m % (params.problem_size.stride_w);

    int start_r, start_s;
    params.filter_s_divmod(start_r, start_s, filter_tile_m);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA
    if (start_r < params.problem_size.R && start_s < params.problem_size.S) {
      // Scale gemm_k_iterations for strided dgrad
      int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S)
        ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s);

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.iterator_A,
        params.problem_size,
        params.ptr_A,
        thread_idx,
        start_r, start_s,
        MatrixCoord(
          threadblock_tile_idx.m() * Mma::Shape::kM,
          threadblock_tile_idx.k() * Mma::Shape::kK
        )
      );

      typename Mma::IteratorB iterator_B(
        params.iterator_B,
        params.problem_size,
        params.ptr_B,
        thread_idx,
        start_r, start_s,
        MatrixCoord(
          threadblock_tile_idx.k() * Mma::Shape::kK,
          threadblock_tile_idx.n() * Mma::Shape::kN
        )
      );

      //
      // Main loop
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      // Compute threadblock-scoped matrix multiply-add
      mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
    }

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    // Construct the semaphore.
    int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();

    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    // Compute logical position within grid
    threadblock_tile_idx =
      threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);

    // If performing a reduction via split-K, fetch the initial synchronization
    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

      // Fetch the synchronization lock initially but do not block.
      semaphore.fetch();

      // Indicate which position in a serial reduction the output operator is currently updating
      output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
    }

    MatrixCoord threadblock_offset(
      threadblock_tile_idx.m() * Mma::Shape::kM,
      threadblock_tile_idx.n() * Mma::Shape::kN
    );

    // Tile iterator writing to destination tensor
    typename Epilogue::OutputTileIterator iterator_D(
      params.iterator_D,
      params.ptr_D,
      ConvOutputIteratorParameter::extent(params.problem_size),
      thread_idx,
      start_r, start_s,
      threadblock_offset
    );

    // Tile iterator reading from source accumulator tensor
    typename Epilogue::OutputTileIterator iterator_C(
      params.iterator_C,
      params.ptr_C,
      ConvOutputIteratorParameter::extent(params.problem_size),
      thread_idx,
      start_r, start_s,
      threadblock_offset
    );

    // Construct the epilogue
    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_idx.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_idx.k());

      __threadfence();
    }
    // Each split-k-slice writes to a unique tensor location
    else if (params.split_k_mode == SplitKMode::kParallel) {
      iterator_D.add_pointer_offset(threadblock_tile_idx.k() *
        cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size));
    }

    // Run efficient epilogue
    epilogue(output_op, iterator_D, accumulators, iterator_C);

    //
    // Release the semaphore
    //

    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_idx.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
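
The FastDivmod member above replaces the runtime division and modulo that recover (start_r, start_s) from a tile index with precomputed multiply-shift arithmetic. The underlying decomposition, written out with plain operators (the stride and tile count are assumed example values):

#include <cassert>

int main() {
  // Assumed problem: convolution stride 2 in w; four CTA tile groups along M.
  int stride_w = 2;

  for (int filter_tile_m = 0; filter_tile_m < 4; ++filter_tile_m) {
    // Equivalent of params.filter_s_divmod(start_r, start_s, filter_tile_m):
    int start_r = filter_tile_m / stride_w;
    int start_s = filter_tile_m % stride_w;

    // Each (start_r, start_s) class owns the filter taps reachable from it.
    assert(start_s < stride_w);
    (void)start_r;
  }
  return 0;
}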
@ -0,0 +1,493 @@
/*! \file
    \brief Template for a pipelined Implicit GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/semaphore.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                                 ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,                            ///! Epilogue
  typename ThreadblockSwizzle_,                  ///! Threadblock swizzling function
  conv::Operator ConvOperator,                   ///! Convolutional operator (Fprop, Dgrad, Wgrad)
  typename ConvProblemSize_ = Conv2dProblemSize  ///! Convolutional operator on 2D or 3D problem
>
struct ImplicitGemmConvolutionWithFusedEpilogue {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static Operator const kConvolutionalOperator = ConvOperator;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename EpilogueOutputOp::ElementOutput;

  /// Set output tensor C layout
  using LayoutC = LayoutA;

  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
  using ElementCompute = typename EpilogueOutputOp::ElementCompute;

  using WarpMmaOperator = typename Mma::Policy::Operator;

  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename ArchMmaOperator::Operator;

  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;

  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename WarpMmaOperator::Shape;
  using InstructionShape = typename ArchMmaOperator::Shape;

  static int const kStages = Mma::kStages;
  static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
  static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  using TensorRefA = typename Mma::IteratorA::TensorRef;
  using TensorRefB = typename Mma::IteratorB::TensorRef;
  using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;

  /// Check iterator A and B convolution dimension are the same and
  // set device::ImplicitGemmConvolution::kConvDim
  static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
    "Convolution on different dimensions is not supported");
  static int const kConvDim = Mma::IteratorA::kConvDim;

  /// Conv dimension and problem size structure (Conv2d or Conv3d)
  using ConvProblemSize = ConvProblemSize_;

  /// Wgrad C stride idx for implicit gemm algorithm
  // Conv2d row-major matrix C (KxRSC)
  // Conv3d row-major matrix C (KxTRSC)
  static int const kWgradCStrideIdx =
    cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;

  /// This chooses the appropriate stride element of the C tensor.
  static int const kTensorCStrideIdx =
    (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);

  //
  //
  //
  using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
    LayoutC,
    typename Epilogue::OutputTileIterator::Layout,
    TensorRefC,
    ConvOperator,
    ConvProblemSize
  >;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    ConvProblemSize problem_size;
    TensorRefA ref_A;
    TensorRefB ref_B;
    TensorRefC ref_C;
    TensorRefC ref_D;

    typename EpilogueOutputOp::Params output_op;
    SplitKMode split_k_mode;

    void * ptr_Vector;
    void * ptr_Tensor;

    typename LayoutC::Stride::Index ldr;
    typename LayoutC::Stride::Index ldt;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ConvProblemSize const & problem_size
    ):
      problem_size(problem_size) { }

    CUTLASS_HOST_DEVICE
    Arguments(
      ConvProblemSize const & problem_size,
      TensorRefA const & ref_A,
      TensorRefB const & ref_B,
      TensorRefC const & ref_C,
      TensorRefC const & ref_D,
      typename EpilogueOutputOp::Params const & output_op,
      SplitKMode const & split_k_mode = SplitKMode::kSerial,
      void * ptr_Vector = nullptr,
      void * ptr_Tensor = nullptr,
      typename LayoutC::Stride::Index ldr = 0,
      typename LayoutC::Stride::Index ldt = 0
    ):
      problem_size(problem_size),
      ref_A(ref_A),
      ref_B(ref_B),
      ref_C(ref_C),
      ref_D(ref_D),
      output_op(output_op),
      split_k_mode(split_k_mode),
      ptr_Vector(ptr_Vector),
      ptr_Tensor(ptr_Tensor),
      ldr(ldr),
      ldt(ldt)
    {

    }

  };

  /// Parameters structure
  struct Params {
    ConvProblemSize problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    gemm::GemmCoord implicit_gemm_problem_size;
    int swizzle_log_tile;

    int gemm_k_iterations;
    typename Mma::IteratorA::Params iterator_A;
    typename Mma::IteratorA::Element const *ptr_A;
    typename Mma::IteratorB::Params iterator_B;
    typename Mma::IteratorB::Element const *ptr_B;
    typename Epilogue::OutputTileIterator::Params iterator_C;
    typename Epilogue::OutputTileIterator::Element *ptr_C;
    typename Epilogue::OutputTileIterator::Params iterator_D;
    typename Epilogue::OutputTileIterator::Element *ptr_D;
    typename EpilogueOutputOp::Params output_op;
    int *semaphore;
    SplitKMode split_k_mode;

    typename Epilogue::TensorTileIterator::Params params_Tensor;
    void * ptr_Vector;
    typename LayoutC::Stride::Index ldr;
    void * ptr_Tensor;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      swizzle_log_tile(0),
      gemm_k_iterations(0),
      ptr_Vector(nullptr),
      ldr(0),
      ptr_Tensor(nullptr)
    { }

    ///
    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      int *semaphore = nullptr
    ):
      problem_size(args.problem_size),
      implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
      iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
      ptr_A(args.ref_A.data()),
      iterator_B(args.problem_size, args.ref_B.layout()),
      ptr_B(args.ref_B.data()),
      iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)),
      ptr_C(args.ref_C.data()),
      iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)),
      ptr_D(args.ref_D.data()),
      output_op(args.output_op),
      semaphore(semaphore),
      split_k_mode(args.split_k_mode),
      params_Tensor(args.ldt),
      ptr_Vector(args.ptr_Vector),
      ldr(args.ldr),
      ptr_Tensor(args.ptr_Tensor)

    {
      gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size);

      ThreadblockSwizzle threadblock_swizzle;

      grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        implicit_gemm_problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices);

      swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  ImplicitGemmConvolutionWithFusedEpilogue() { }

  /// Executes one ImplicitGEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_idx =
      threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {

      return;
    }

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.iterator_A,
      params.problem_size,
      params.ptr_A,
      thread_idx,
      MatrixCoord(
        threadblock_tile_idx.m() * Mma::Shape::kM,
        threadblock_tile_idx.k() * Mma::Shape::kK
      )
    );

    typename Mma::IteratorB iterator_B(
      params.iterator_B,
      params.problem_size,
      params.ptr_B,
      thread_idx,
      MatrixCoord(
        threadblock_tile_idx.k() * Mma::Shape::kK,
        threadblock_tile_idx.n() * Mma::Shape::kN
      )
    );

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    // Construct the semaphore.
    int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();

    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    // Compute logical position within grid
    threadblock_tile_idx =
      threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // If performing a reduction via split-K, fetch the initial synchronization
    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

      // Fetch the synchronization lock initially but do not block.
      semaphore.fetch();

      // Indicate which position in a serial reduction the output operator is currently updating
      output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
    }

    MatrixCoord threadblock_offset(
      threadblock_tile_idx.m() * Mma::Shape::kM,
      threadblock_tile_idx.n() * Mma::Shape::kN
    );

    // Tile iterator writing to destination tensor
    typename Epilogue::OutputTileIterator iterator_D(
      params.iterator_D,
      params.ptr_D,
      ConvOutputIteratorParameter::extent(params.problem_size),
      thread_idx,
      threadblock_offset
    );

    // Tile iterator reading from source accumulator tensor
    typename Epilogue::OutputTileIterator iterator_C(
      params.iterator_C,
      params.ptr_C,
      ConvOutputIteratorParameter::extent(params.problem_size),
      thread_idx,
      threadblock_offset
    );

    typename Epilogue::ElementTensor *ptr_Tensor =
      static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);

    // Define the reduction output pointer and move to the appropriate place
    typename Epilogue::ElementVector *ptr_Vector =
      static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);

    // Additional tensor to load from
    typename Epilogue::TensorTileIterator tensor_iterator(
      params.params_Tensor,
      // Only the final block outputs Tensor
      ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) &&
       (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1))
        ? nullptr
        : ptr_Tensor,
      ConvOutputIteratorParameter::extent(params.problem_size),
      thread_idx,
      threadblock_offset);

    // Construct the epilogue
    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Move to appropriate location for this output tile
    if (ptr_Vector) {
      ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr;
    }

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_idx.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_idx.k());

      __threadfence();
    }
    // Each split-k-slice writes to a unique tensor location
    else if (params.split_k_mode == SplitKMode::kParallel) {
      iterator_D.add_pointer_offset(threadblock_tile_idx.k() *
        cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size));
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(output_op,
      // Only the final block uses Vector
      ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) &&
       (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1))
        ? nullptr
        : ptr_Vector,
      iterator_D,
      accumulators,
      iterator_C,
      tensor_iterator,
      ConvOutputIteratorParameter::extent(params.problem_size),
      threadblock_offset);

    //
    // Release the semaphore
    //

    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_idx.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
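
Both kernels above use the same serial split-K handshake: slice k spins until the per-tile semaphore reads k, accumulates into D, then releases k+1, and the final slice resets the lock to zero for the next launch. A host-side illustration of the protocol, using std::atomic in place of the device Semaphore:

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  int const slices = 4;
  std::atomic<int> semaphore{0};   // stands in for params.semaphore[block_idx]
  float d = 0.0f;                  // stands in for the output tile in D

  std::vector<std::thread> ctas;
  for (int k = 0; k < slices; ++k) {
    ctas.emplace_back([&, k] {
      while (semaphore.load() != k) {}           // semaphore.wait(k)
      d += 1.0f;                                 // accumulate this slice
      int lock = (k + 1 == slices) ? 0 : k + 1;  // final slice resets the lock
      semaphore.store(lock);                     // semaphore.release(lock)
    });
  }
  for (auto &t : ctas) t.join();

  std::printf("reduced D = %.1f, semaphore reset to %d\n", d, semaphore.load());
  return 0;
}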
@ -55,12 +55,29 @@ namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_,
  conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
>
class Conv2dDgradFilterTileAccessIteratorAnalytic;

/////////////////////////////////////////////////////////////////////////////////////////////////

// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs
// on non-contributing w positions
template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv2dDgradFilterTileAccessIteratorAnalytic {
class Conv2dDgradFilterTileAccessIteratorAnalytic <
  Shape_,
  Element_,
  ThreadMap_,
  conv::StrideSupport::kStrided
> {
public:

  //
@ -90,6 +107,197 @@ public:

  using Params = Conv2dAnalyticParams<Layout>;

private:

  Params const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
  int filter_r_;
  int filter_s_;
  int start_r_;
  int start_s_;
  int offset_k_[ThreadMap::Iterations::kStrided];
  int offset_c_[ThreadMap::Iterations::kContiguous];

public:

  CUTLASS_HOST_DEVICE
  Conv2dDgradFilterTileAccessIteratorAnalytic(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    int start_r, int start_s,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_r_(start_r),
    filter_s_(start_s),
    start_r_(start_r),
    start_s_(start_s) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    CUTLASS_PRAGMA_UNROLL
    for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
      offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous()
        + c * ThreadMap::Delta::kContiguous;
    }

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_k_[s] =
        threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
    }
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // Moves filter_s
    filter_s_ += problem_size_.stride_w;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    // Restore filter_s
    filter_s_ = start_s_;

    // Move filter_r
    filter_r_ += problem_size_.stride_h;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    // Restore filter_r
    filter_r_ = start_r_;

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_k_[s] += Shape::kRow * problem_size_.split_k_slices;
    }
  }

  /// Returns the coordinate in the filter tensor w that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {

    int c = offset_c_[iteration_contiguous_];
    int k = offset_k_[iteration_strided_];

    return TensorCoord(k, filter_r_, filter_s_, c);
  }

  /// Returns true if the current coordinate is within the filter tensor w
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord coord = at();

    return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);

  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;
    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.C % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad
// on problem sizes with stride = {1x1}
template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv2dDgradFilterTileAccessIteratorAnalytic <
  Shape_,
  Element_,
  ThreadMap_,
  conv::StrideSupport::kUnity
> {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  static_assert(sizeof_bits<Element>::value >= 8,
    "DGRAD requires elements of size 8b or larger.");

  //
  // Parameters structure
  //

  using Params = Conv2dAnalyticParams<Layout>;

private:

  Params const &params_;

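The strided specialization's advance() above walks only the filter taps congruent to the starting position modulo the convolution stride, since the other taps cannot contribute to this CTA's output rows. The traversal order, extracted into a standalone loop with assumed dimensions:

#include <cstdio>

int main() {
  // Assumed 5x5 filter with stride {2, 2}, starting position (1, 0).
  int R = 5, S = 5, stride_h = 2, stride_w = 2;
  int start_r = 1, start_s = 0;

  // Mirrors advance(): step s by stride_w, wrap to start_s, then step r.
  for (int r = start_r; r < R; r += stride_h) {
    for (int s = start_s; s < S; s += stride_w) {
      std::printf("(r=%d, s=%d) ", r, s);
    }
  }
  std::printf("\n");  // visits 6 of the 25 taps
  return 0;
}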
@ -62,7 +62,23 @@ template <
  typename ThreadMap_,
  conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
>
class Conv2dDgradFilterTileAccessIteratorOptimized {
class Conv2dDgradFilterTileAccessIteratorOptimized;

/////////////////////////////////////////////////////////////////////////////////////////////////

// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad
// on problem sizes with stride = {1x1}
template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv2dDgradFilterTileAccessIteratorOptimized <
  Shape_,
  Element_,
  ThreadMap_,
  conv::StrideSupport::kUnity
> {
public:

  //
@ -79,7 +95,7 @@ public:
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = StrideSupport_;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

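The refactoring above turns the iterator's primary template into a forward declaration and gives each StrideSupport value its own partial specialization, so the unity-stride fast path and the general strided path are selected at compile time. A toy sketch of that dispatch pattern (names below are illustrative, not CUTLASS types):

enum class StrideSupport { kUnity, kStrided };

// Primary template: declared only, never defined.
template <typename Element, StrideSupport Support = StrideSupport::kUnity>
class FilterIterator;

// Fast path for stride {1x1}.
template <typename Element>
class FilterIterator<Element, StrideSupport::kUnity> {};

// General path that skips non-contributing filter taps.
template <typename Element>
class FilterIterator<Element, StrideSupport::kStrided> {};

int main() {
  FilterIterator<float> unity;                             // default: kUnity
  FilterIterator<float, StrideSupport::kStrided> strided;  // explicit
  (void)unity;
  (void)strided;
  return 0;
}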
@ -37,6 +37,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/functional.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
@ -109,7 +110,7 @@ public:
  // Parameters structure
  //

  using Params = Conv2dAnalyticParams<Layout>;
  using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams;

private:

@ -122,36 +123,13 @@ private:
  int filter_k_;
  int filter_r_;
  int filter_s_;
  int start_r_;
  int start_s_;

  int offset_n_[ThreadMap::Iterations::kStrided];
  int offset_w_[ThreadMap::Iterations::kStrided];
  int offset_h_[ThreadMap::Iterations::kStrided];

private:
  int offset_p_[ThreadMap::Iterations::kStrided];
  int offset_q_[ThreadMap::Iterations::kStrided];

  /// Returns the coordinate in the output tensor Dy that is currently pointed to
  /// by the iterator but DOES NOT scale by the convolution stride. This is needed
  /// to compute predicates in the valid() method. The return value of the public at()
  /// method is correctly scaled.
  CUTLASS_HOST_DEVICE
  TensorCoord unscaled_at_() const {
    int n = offset_n_[iteration_strided_];
    int h = offset_h_[iteration_strided_];
    int w = offset_w_[iteration_strided_];

    int r = filter_r_;
    int s = filter_s_;

    if (problem_size_.mode == Mode::kConvolution) {
      r = (problem_size_.R - 1 - r);
      s = (problem_size_.S - 1 - s);
    }

    int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h);
    int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w);

    return TensorCoord(n, p, q, filter_k_);
  }

public:

@ -161,34 +139,68 @@ public:
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    int start_r, int start_s,
    MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_k_(0),
    filter_r_(0),
    filter_s_(0) {
    filter_k_(0),
    filter_r_(start_r),
    filter_s_(start_s),
    start_r_(start_r),
    start_s_(start_s) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_k_ = threadblock_offset.column() + thread_coord.contiguous();

    int filter_r = filter_r_;
    int filter_s = filter_s_;

    if (problem_size_.mode == Mode::kConvolution) {
      filter_r = (problem_size_.R - 1 - filter_r);
      filter_s = (problem_size_.S - 1 - filter_s);
    }

    // Starting h, w positions for filter position in gemm_k=0
    int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h);
    int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w);

    // Effective P and Q for filter position required for remapping NHW rows
    int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
    int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w;

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
      int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter;

      offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W);
      int residual = offset_nhw % (problem_size_.H * problem_size_.W);
      // (STEP 1) [reorder NHW rows to start with same filter positions]
      offset_n_[s] = offset_npq / (P * Q);
      int residual = offset_npq % (P * Q);

      offset_h_[s] = residual / problem_size_.W;
      offset_w_[s] = residual % problem_size_.W;
      int p = (residual / Q);
      int q = (residual % Q);

      int mapped_h = (start_h + p * problem_size_.stride_h);
      int mapped_w = (start_w + q * problem_size_.stride_w);

      // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0
      // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible
      // by stride_h and stride_w
      offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h;
      offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w;
    }
  }

  CUTLASS_HOST_DEVICE
  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
    return Params(problem_size, layout);
    return Params(problem_size,
                  layout,
                  sizeof_bits<Element>::value,
                  {Shape::kRow, Shape::kColumn});
  }

  /// Overrides the internal iteration index
@ -206,18 +218,26 @@ public:

  CUTLASS_HOST_DEVICE
  void advance() {
    // move to the next tile
    ++filter_s_;

    // Move filter_s by stride_w
    filter_s_ += problem_size_.stride_w;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    filter_s_ = 0;
    ++filter_r_;

    // Restore filter_s
    filter_s_ = start_s_;

    // Move filter_r by stride_h
    filter_r_ += problem_size_.stride_h;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    filter_r_ = 0;

    // Restore filter_r
    filter_r_ = start_r_;

    // Move filter_k
    filter_k_ += Shape_::kColumn * problem_size_.split_k_slices;
  }

@ -225,14 +245,20 @@ public:
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {
    int n = offset_n_[iteration_strided_];
    int p = offset_p_[iteration_strided_];
    int q = offset_q_[iteration_strided_];

    int conv_sign = (problem_size_.mode == Mode::kConvolution ? 1 : -1);

    TensorCoord coord = unscaled_at_();
    p += (conv_sign * (filter_r_ / problem_size_.stride_h));
    q += (conv_sign * (filter_s_ / problem_size_.stride_w));

    return TensorCoord(
      coord.n(),
      coord.h() / problem_size_.stride_h,
      coord.w() / problem_size_.stride_w,
      coord.c());
      n,
      p,
      q,
      filter_k_);
  }

@ -240,11 +266,9 @@ public:
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord unscaled_coord = unscaled_at_();
    TensorCoord coord = at();

    return
      !(unscaled_coord.h() % problem_size_.stride_h) && !(unscaled_coord.w() % problem_size_.stride_w) &&
      coord.n() < problem_size_.N &&
      coord.h() >= 0 && coord.h() < problem_size_.P &&
      coord.w() >= 0 && coord.w() < problem_size_.Q &&

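The constructor above re-indexes output rows for one filter tap: only every stride_h-th row of Dy can pair with that tap, so rows are counted with an "effective" extent P (and likewise Q in the w dimension). The arithmetic in isolation, with assumed problem dimensions (the non-negative remainder below matches the std::abs form used in the code for these values):

#include <cstdio>
#include <cstdlib>

int main() {
  // Assumed problem: 14x14 activation, stride {2, 2}, pad {1, 1}, tap (1, 1).
  int H = 14, W = 14, stride_h = 2, stride_w = 2;
  int pad_h = 1, pad_w = 1, filter_r = 1, filter_s = 1;

  // Starting h, w positions for this filter position at gemm_k = 0.
  int start_h = std::abs((pad_h - filter_r) % stride_h);
  int start_w = std::abs((pad_w - filter_s) % stride_w);

  // Effective P and Q used to remap NHW rows for this tap.
  int P = (H - start_h + stride_h - 1) / stride_h;
  int Q = (W - start_w + stride_w - 1) / stride_w;

  std::printf("start=(%d,%d), effective PxQ = %dx%d\n", start_h, start_w, P, Q);
  return 0;
}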
@@ -32,6 +32,7 @@
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/


#pragma once

#include "cutlass/cutlass.h"
@@ -62,11 +63,26 @@ template <
  typename ThreadMap_,
  conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized {
public:
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
/////////////////////////////////////////////////////////////////////////////////////////////////

  static_assert(StrideSupport_ == conv::StrideSupport::kUnity,
    "Only unit-stride dgrad is supported at this time.");
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradOutputGradientTileAccessIteratorOptimized: unity-stride specialization, optimized
// for dgrad problems with stride = {1x1}
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
  Shape_,
  Element_,
  ThreadMap_,
  conv::StrideSupport::kUnity
> {
public:

  //
  // Types
@@ -417,5 +433,3 @@ public:
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -99,7 +99,7 @@ public:

private:

  Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params_;
  Params const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
@@ -118,7 +118,7 @@ public:

  CUTLASS_HOST_DEVICE
  Conv2dFpropActivationTileAccessIteratorOptimized(
    Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params,
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,

@@ -77,6 +77,38 @@ struct Conv2dAnalyticParams {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalytic
struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams {

  using Layout = layout::TensorNHWC;

  Layout layout;
  int tiled_rows_per_filter;

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { }

  CUTLASS_HOST_DEVICE
  Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,                 ///< layout object
    int element_size_bits,                ///< size of each element in bits
    MatrixCoord threadblock_shape
  ): layout(layout) {

    int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row());

    tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row();

  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED

CUTLASS_HOST_DEVICE
@@ -199,6 +231,32 @@ struct Conv2dFpropActivationIteratorOptimizedParams<layout::TensorNHWC> {
    // logical offset added to internal channel counter - units are elements, not bytes
    filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices;
  }

#if 0
  /// Prints internal state.
  CUTLASS_HOST_DEVICE
  void print() {
    auto stride = layout.stride();
    printf(
      "Conv2dFpropActivationIteratorOptimizedParams:\n"
      "  layout(w: %d, h: %d, n: %d)\n"
      "  inc_next[%ld, %ld, %ld]\n"
      "  filter_c_delta(%d) - PQ(%d)\n"
      "  pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n"
      "  q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n",
      stride[0], stride[1], stride[2],
      inc_next[0], inc_next[1], inc_next[2],
      filter_c_delta,
      PQ,
      pq_divmod.divisor,
      pq_divmod.multiplier,
      pq_divmod.shift_right,
      q_divmod.divisor,
      q_divmod.multiplier,
      q_divmod.shift_right
    );
  }
#endif
};

/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized
@@ -324,6 +382,23 @@ struct Conv2dFpropFilterIteratorOptimizedParams<layout::TensorNHWC>

    filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices;
  }

#if 0
  /// Prints internal state.
  CUTLASS_HOST_DEVICE
  void print() {
    auto stride = layout.stride();
    printf(
      "Conv2dFpropFilterIteratorOptimizedParams:\n"
      "  layout[%d, %d, %d]\n"
      "  RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n",
      stride[0], stride[1], stride[2],
      RS,
      filter_c_delta,
      inc_next_k, inc_next_rs, inc_next_c
    );
  }
#endif
};

template<int Interleaved_>
@@ -382,6 +457,9 @@ struct Conv2dFpropFilterIteratorOptimizedParams<layout::TensorCxRSKx<Interleaved
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// Dgrad Optimized Dy params (layout::TensorNHWC)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator
struct Conv2dDgradOutputGradientIteratorOptimizedParams {

@@ -449,7 +527,9 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams {
  }
};

/// Parameters object for Conv2d DGRAD Filter (w) iterator
/////////////////////////////////////////////////////////////////////////////////////////////////
// Dgrad Optimized w params (layout::TensorNHWC)
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Conv2dDgradFilterIteratorOptimizedParams {

  using Layout = layout::TensorNHWC;
@@ -609,6 +689,25 @@ struct Conv2dWgradActivationIteratorOptimizedParams {
  }
};

struct PredicatedScaleBiasVectorAccessIteratorParams {
public:
  /// Default ctor
  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIteratorParams() { }

  /// Ctor from a Conv2dProblemSize and a pitch-linear layout
  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIteratorParams(
    Conv2dProblemSize const &problem_size,
    layout::PitchLinear const &layout) {}

  /// Ctor from a Conv2dProblemSize and a row-major layout
  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIteratorParams(
    Conv2dProblemSize const &problem_size,
    layout::RowMajor const &layout) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock

@@ -166,6 +166,125 @@ public:
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// Strided Dgrad Tile Iterator
template <typename TileAccessIterator_>
class TileIteratorStridedDgrad {
public:
  using TileAccessIterator = TileAccessIterator_;

  using Shape = typename TileAccessIterator::Shape;
  using Element = typename TileAccessIterator::Element;
  using Layout = typename TileAccessIterator::Layout;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = typename TileAccessIterator::ThreadMap;
  using AccessType = typename TileAccessIterator::AccessType;
  using TensorRef = typename TileAccessIterator::TensorRef;
  using Index = typename TileAccessIterator::Index;
  using LongIndex = typename TileAccessIterator::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm;
  static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport;
  using Params = typename TileAccessIterator::Params;
  static int const kConvDim = TileAccessIterator::kConvDim;
  using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
    Element,
    ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

private:

  /// Internal state
  TileAccessIterator tile_access_iterator_;

public:

  /// Constructor
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    int start_r, int start_s,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { }

  CUTLASS_HOST_DEVICE
  static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {
    return TileAccessIterator::getParams(problem_size, layout);
  }


  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    tile_access_iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad &operator++() {
    tile_access_iterator_.advance();
    return *this;
  }

  /// Advances to the next tile in memory.
  CUTLASS_HOST_DEVICE
  TileIteratorStridedDgrad operator++(int) {
    TileIteratorStridedDgrad self(*this);
    operator++();
    return self;
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {

    frag.clear();
    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {

        cutlass::arch::global_load<
          AccessType,
          sizeof(AccessType)
        >(
          frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
          tile_access_iterator_.get() + pointer_offset,
          tile_access_iterator_.valid()
        );

        ++tile_access_iterator_;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    tile_access_iterator_.set_iteration_index(0);
    load_with_pointer_offset(frag, 0);
  }

  CUTLASS_DEVICE
  void advance() {
    tile_access_iterator_.advance();
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(ConvProblemSize const &problem_size) {

    // dispatch to iterator implementation
    return TileAccessIterator::can_implement(problem_size);
  }
};
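
// Editorial sketch -- not part of the diff. load_with_pointer_offset() above clears the
// fragment first and then issues one guarded vector load per access, so out-of-bounds
// accesses simply leave zeros behind. cutlass::arch::global_load emits predicated PTX
// rather than a branch, but the observable behavior matches this plain-CUDA reduction:
#include <cuda_runtime.h>

template <typename AccessType>
__device__ void guarded_load(AccessType &dst, void const *ptr, bool valid) {
  if (valid) {
    dst = *reinterpret_cast<AccessType const *>(ptr);
  }
  // else: dst keeps the zero written by frag.clear()
}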

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock

@@ -243,6 +243,7 @@ public:
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
@@ -250,5 +251,3 @@ public:
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////



@@ -196,7 +196,7 @@ private:
  CUTLASS_HOST_DEVICE
  TensorCoord at_(int offset_npq, int k) const {

    // The subseqnet fast_divmod() operations are equivalent to the following logical computation:
    // The subsequent fast_divmod() operations are equivalent to the following logical computation:
    //
    //
    // int npq = offset_npq;

@@ -355,6 +355,145 @@ struct Conv3dDgradFilterIteratorOptimizedParams {
  }
};

/// Parameters object for Conv3d WGRAD OutputGradient iterator
struct Conv3dWgradOutputGradientIteratorOptimizedParams {

  using Layout = layout::TensorNDHWC;
  using LongIndex = typename Layout::LongIndex;

  Layout layout;

  int NZPQ;            // precomputed product of N*Z*P*Q for clearing predicates
  int ZPQ;             // product of Z*P*Q
  unsigned zpq_mul;    // precomputed quantities for fast computation of div/% by ZPQ
  unsigned zpq_shr;    //   in device code.

  int PQ;              // product of P*Q
  unsigned pq_mul;     // precomputed quantities for fast computation of div/% by PQ
  unsigned pq_shr;     //   in device code.

  unsigned q_mul;      // precomputed quantities for fast computation of div/% by Q
  unsigned q_shr;      //   in device code.

  LongIndex offset_next_strided;     // offset in units of bytes to next nzpq coordinate within tile
  LongIndex offset_next_contiguous;  // offset in units of bytes to next k coordinate within tile
  LongIndex inc_next_nzpq;           // offset in units of bytes to next nzpq position in subsequent tile

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Conv3dWgradOutputGradientIteratorOptimizedParams() { }

  CUTLASS_HOST_DEVICE
  Conv3dWgradOutputGradientIteratorOptimizedParams(
    Conv3dProblemSize const &problem_size,
    Layout const &layout,
    int element_size_bits,
    MatrixCoord threadblock_shape,
    int thread_count,
    int access_size,
    layout::PitchLinearCoord threadmap_iterations,
    layout::PitchLinearCoord threadmap_delta
  ): layout(layout) {

    TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient",
      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    // Incremental offsets in units of bytes: (number of elements) * element_size_bits / 8
    offset_next_strided = (threadmap_delta.strided() * layout.stride()[0])
      * element_size_bits / 8;

    offset_next_contiguous = (threadmap_delta.contiguous())
      * element_size_bits / 8;

    inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0])
      * element_size_bits / 8;

    // Precompute several quantities for fast modulo arithmetic.
    NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q;
    ZPQ = problem_size.Z * problem_size.P * problem_size.Q;
    find_divisor(zpq_mul, zpq_shr, ZPQ);

    PQ = problem_size.P * problem_size.Q;
    find_divisor(pq_mul, pq_shr, PQ);

    find_divisor(q_mul, q_shr, problem_size.Q);

  }
};
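
// Editorial sketch -- not part of the diff. find_divisor() precomputes a (multiplier,
// shift) pair per divisor so device code can replace each x / ZPQ, x / PQ, x / Q with a
// multiply-high plus shift instead of an integer-divide instruction. CUTLASS's exact
// routines live in cutlass/fast_math.h; this standalone version shows the underlying
// round-up-multiplier technique (assumes a compiler with unsigned __int128, e.g.
// GCC/Clang, and a divisor d >= 2):
#include <cassert>
#include <cstdint>

struct FastDivisor {
  uint64_t mul;  // ceil(2^64 / d)
  uint32_t d;
};

inline FastDivisor make_fast_divisor(uint32_t d) {
  assert(d >= 2);
  // ceil(2^64 / d), computed without 2^64 overflowing: floor((2^64 - 1) / d) + 1
  return FastDivisor{ ~uint64_t(0) / d + 1, d };
}

inline void fast_divmod(uint32_t x, FastDivisor f, uint32_t &quo, uint32_t &rem) {
  // quotient = high 64 bits of the 128-bit product x * mul; exact for all 32-bit x
  quo = uint32_t((unsigned __int128)x * f.mul >> 64);
  rem = x - quo * f.d;
}

// e.g. with ZPQ = 4199: fast_divmod(1000000, make_fast_divisor(4199), q, r)
// yields q == 238, r == 638, matching 1000000 / 4199 and 1000000 % 4199.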

/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator
struct Conv3dWgradActivationIteratorOptimizedParams {

  using Layout = layout::TensorNDHWC;

  Layout layout;

  int RSC;             // product of R*S*C
  unsigned rsc_mul;    // precomputed quantities for fast computation of div/% by RSC
  unsigned rsc_shr;    //   in device code.

  int SC;              // product of S*C
  unsigned sc_mul;     // precomputed quantities for fast computation of div/% by SC
  unsigned sc_shr;     //   in device code.

  unsigned c_mul;      // precomputed quantities for fast computation of div/% by C
  unsigned c_shr;      //   in device code.

  int ZPQ;             // product of Z*P*Q
  unsigned zpq_mul;    // precomputed quantities for fast computation of div/% by ZPQ
  unsigned zpq_shr;    //   in device code.

  int PQ;              // product of P*Q
  unsigned pq_mul;     // precomputed quantities for fast computation of div/% by PQ
  unsigned pq_shr;     //   in device code.

  unsigned q_mul;      // precomputed quantities for fast computation of div/% by Q
  unsigned q_shr;      //   in device code.

  //
  // Methods
  //
  CUTLASS_HOST_DEVICE
  Conv3dWgradActivationIteratorOptimizedParams() { }

  CUTLASS_HOST_DEVICE
  Conv3dWgradActivationIteratorOptimizedParams(
    Conv3dProblemSize const &problem_size,
    Layout const &layout,
    int element_size_bits,
    MatrixCoord threadblock_shape,
    int thread_count,
    int access_size,
    layout::PitchLinearCoord threadmap_iterations,
    layout::PitchLinearCoord threadmap_delta
  ): layout(layout) {

    TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation",
      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    // Precompute several quantities for fast modulo arithmetic.
    RSC = problem_size.R * problem_size.S * problem_size.C;
    find_divisor(rsc_mul, rsc_shr, RSC);

    SC = problem_size.S * problem_size.C;
    find_divisor(sc_mul, sc_shr, SC);

    find_divisor(c_mul, c_shr, problem_size.C);

    ZPQ = problem_size.Z * problem_size.P * problem_size.Q;
    find_divisor(zpq_mul, zpq_shr, ZPQ);

    PQ = problem_size.P * problem_size.Q;
    find_divisor(pq_mul, pq_shr, PQ);

    find_divisor(q_mul, q_shr, problem_size.Q);

  }
};

} // namespace threadblock
} // namespace conv
} // namespace cutlass

@@ -45,6 +45,7 @@
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/conv/threadblock/conv3d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -86,62 +87,28 @@ public:
  // Parameters structure
  //

  struct Params {

    Layout layout;

    int RSC;             // product of R*S*C
    unsigned rsc_mul;    // precomputed quantities for fast computation of div/% by RSC
    unsigned rsc_shr;    //   in device code.

    int SC;              // product of S*C
    unsigned sc_mul;     // precomputed quantities for fast computation of div/% by SC
    unsigned sc_shr;     //   in device code.

    unsigned c_mul;      // precomputed quantities for fast computation of div/% by C
    unsigned c_shr;      //   in device code.

    int ZPQ;             // product of Z*P*Q
    unsigned zpq_mul;    // precomputed quantities for fast computation of div/% by ZPQ
    unsigned zpq_shr;    //   in device code.

    int PQ;              // product of P*Q
    unsigned pq_mul;     // precomputed quantities for fast computation of div/% by PQ
    unsigned pq_shr;     //   in device code.

    unsigned q_mul;      // precomputed quantities for fast computation of div/% by Q
    unsigned q_shr;      //   in device code.

  struct Params : Conv3dWgradActivationIteratorOptimizedParams {
    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }
    Params() {}

    CUTLASS_HOST_DEVICE
    Params(
      Conv3dProblemSize const &problem_size,
      Layout const &layout
    ): layout(layout) {
    Params(Conv3dWgradActivationIteratorOptimizedParams const &base)
      : Conv3dWgradActivationIteratorOptimizedParams(base) {}

      // Precompute several quantities for fast modulo arithmetic.
      RSC = problem_size.R * problem_size.S * problem_size.C;
      find_divisor(rsc_mul, rsc_shr, RSC);

      SC = problem_size.S * problem_size.C;
      find_divisor(sc_mul, sc_shr, SC);

      find_divisor(c_mul, c_shr, problem_size.C);

      ZPQ = problem_size.Z * problem_size.P * problem_size.Q;
      find_divisor(zpq_mul, zpq_shr, ZPQ);

      PQ = problem_size.P * problem_size.Q;
      find_divisor(pq_mul, pq_shr, PQ);

      find_divisor(q_mul, q_shr, problem_size.Q);

    }
    CUTLASS_HOST_DEVICE
    Params(Conv3dProblemSize const &problem_size, Layout const &layout)
      : Conv3dWgradActivationIteratorOptimizedParams(
          problem_size,
          layout,
          sizeof_bits<Element>::value,
          {Shape::kRow, Shape::kColumn},
          ThreadMap::kThreads,
          ThreadMap::kElementsPerAccess,
          {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
          {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {}
  };
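
// Editorial note on the refactor above: the host-side divisor bookkeeping moved out of
// the iterator's nested Params into the shared Conv3dWgradActivationIteratorOptimizedParams
// in conv3d_params.h, and the nested Params is reduced to a forwarding shim so existing
// call sites keep compiling. A minimal sketch of the pattern, with hypothetical names:
struct SharedParams {                 // host-precomputed, iterator-agnostic state
  unsigned divisor_mul, divisor_shr;
  SharedParams() {}
  SharedParams(int extent) { /* precompute fast-division constants here */ }
};

class SomeTileIterator {
public:
  struct Params : SharedParams {      // nested name preserved for call sites
    Params() {}
    Params(SharedParams const &base) : SharedParams(base) {}
    Params(int extent) : SharedParams(extent) {}
  };
};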

private:

@@ -45,6 +45,7 @@
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/conv/threadblock/conv3d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -86,61 +87,29 @@ public:
  // Parameters structure
  //

  struct Params {

    Layout layout;

    int NZPQ;            // precomputed product of N*Z*P*Q for clearing predicates
    int ZPQ;             // product of Z*P*Q
    unsigned zpq_mul;    // precomputed quantities for fast computation of div/% by ZPQ
    unsigned zpq_shr;    //   in device code.

    int PQ;              // product of P*Q
    unsigned pq_mul;     // precomputed quantities for fast computation of div/% by PQ
    unsigned pq_shr;     //   in device code.

    unsigned q_mul;      // precomputed quantities for fast computation of div/% by Q
    unsigned q_shr;      //   in device code.

    LongIndex offset_next_strided;     // offset in units of bytes to next nzpq coordinate within tile
    LongIndex offset_next_contiguous;  // offset in units of bytes to next k coordinate within tile
    LongIndex inc_next_nzpq;           // offset in units of bytes to next nzpq position in subsequent tile

  struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams {
    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() {}

    CUTLASS_HOST_DEVICE
    Params() { }
    Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base)
      : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {}

    CUTLASS_HOST_DEVICE
    Params(
      Conv3dProblemSize const &problem_size,
      Layout const &layout
    ): layout(layout) {

      // Incremental offsets in units of bytes: (number of elements) * sizeof_bits<Element>::value / 8
      offset_next_strided = (ThreadMap::Delta::kStrided * layout.stride()[0])
        * sizeof_bits<Element>::value / 8;

      offset_next_contiguous = (ThreadMap::Delta::kContiguous)
        * sizeof_bits<Element>::value / 8;

      inc_next_nzpq = (Shape::kColumn * problem_size.split_k_slices * layout.stride()[0])
        * sizeof_bits<Element>::value / 8;

      // Precompute several quantities for fast modulo arithmetic.
      NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q;
      ZPQ = problem_size.Z * problem_size.P * problem_size.Q;
      find_divisor(zpq_mul, zpq_shr, ZPQ);

      PQ = problem_size.P * problem_size.Q;
      find_divisor(pq_mul, pq_shr, PQ);

      find_divisor(q_mul, q_shr, problem_size.Q);

    }
  };
    Params(Conv3dProblemSize const &problem_size, Layout const &layout)
      : Conv3dWgradOutputGradientIteratorOptimizedParams(
          problem_size,
          layout,
          sizeof_bits<Element>::value,
          {Shape::kRow, Shape::kColumn},
          ThreadMap::kThreads,
          ThreadMap::kElementsPerAccess,
          {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
          {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {}
  };

private:


@@ -377,7 +377,7 @@ public:

  this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
  this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);

  this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
  this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);

include/cutlass/conv/threadblock/threadblock_swizzle.h (new file, 166 lines)
@@ -0,0 +1,166 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements several possible threadblock-swizzling functions mapping blockIdx to
      Convolution problems.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
CUTLASS_HOST_DEVICE
static int get_strided_dgrad_tile_m(
  cutlass::conv::Conv2dProblemSize const &problem_size,
  int tile_size_m) {

  // CTAs in M dimension per starting filter position
  int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m);

  // Inflate the number of CTAs in the M dimension to cover every starting filter position,
  // even those that may fall outside the valid MMA (Dy * w) but are needed to apply the
  // epilogue (beta * Dx_source) and point-wise fusion
  int tile_m = tile_m_per_filter * int(problem_size.stride().product());

  // There is a possible performance optimization here that leads to up to a 2x speedup over
  // the current CUTLASS strided dgrad performance for stride > filter (e.g., stride={2x2}
  // and filter={1x1})
  //
  // * Optimization *
  // Only launch CTAs in the M dimension which contribute to a row in the Dx output
  //
  //
  // * Constraints *
  // (A) stride <= filter, for example, stride={2x2} and filter={3x3}:
  //     - (A.1): There are no constraints for this case, and the optimization does not
  //              affect functionality or performance in this case.
  // (B) stride > filter, for example, stride={2x2} and filter={1x1}:
  //     - (B.1): Dx output tensor should be zero initialized
  //     - (B.2): The kernel epilogue cannot apply beta. Thus, beta should be zero

  return tile_m;
}
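
// Editorial worked example -- not part of the diff. With a hypothetical
// tile_m_per_filter = 3 (CTAs needed to cover one starting filter position) and
// stride = {2, 2}, the launch uses
//
//   tile_m = 3 * (2 * 2) = 12
//
// CTAs in M: one band of 3 CTAs per starting filter position (start_r, start_s) in
// {0,1} x {0,1}, including start positions whose MMA is empty but whose epilogue
// (beta * Dx_source and point-wise fusion) must still run.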
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Threadblock swizzling function for strided dgrad convolution
struct StridedDgradHorizontalThreadblockSwizzle :
  public gemm::threadblock::GemmHorizontalThreadblockSwizzle {

  using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle;

  CUTLASS_HOST_DEVICE
  StridedDgradHorizontalThreadblockSwizzle() { }

  /// Returns the shape of the problem in units of logical tiles
  /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC)
  CUTLASS_HOST_DEVICE
  gemm::GemmCoord get_tiled_shape(
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv2dProblemSize const &problem_size,
    gemm::GemmCoord tile_size,
    int split_k_slices) const {

    gemm::GemmCoord implicit_gemm_problem_size =
      cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);

    // compute number of tiles in m dimension
    int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m());

    // compute number of tiles in n dimension
    int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n();

    return gemm::GemmCoord(
      tile_m,
      tile_n,
      split_k_slices);
  }

  /// Returns the shape of the problem in units of logical tiles
  /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape())
  private:
    using Base::get_tiled_shape;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Threadblock swizzling function for strided dgrad convolution
template <int N = 1>
struct StridedDgradIdentityThreadblockSwizzle :
  public gemm::threadblock::GemmIdentityThreadblockSwizzle<N> {

  using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle<N>;

  CUTLASS_HOST_DEVICE
  StridedDgradIdentityThreadblockSwizzle() { }

  /// Returns the shape of the problem in units of logical tiles
  /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC)
  CUTLASS_HOST_DEVICE
  gemm::GemmCoord get_tiled_shape(
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv2dProblemSize const &problem_size,
    gemm::GemmCoord tile_size,
    int split_k_slices) const {

    gemm::GemmCoord implicit_gemm_problem_size =
      cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);

    // compute number of tiles in m dimension
    int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m());

    // compute number of tiles in n dimension
    int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n();

    return gemm::GemmCoord(
      tile_m,
      tile_n,
      split_k_slices);
  }


  /// Returns the shape of the problem in units of logical tiles
  /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape())
  private:
    using Base::get_tiled_shape;
};

/////////////////////////////////////////////////////////////////////////////////////////////////


} // namespace threadblock
} // namespace conv
} // namespace cutlass
@@ -412,39 +412,61 @@ Coord<Rank, Index> operator/(Coord<Rank, Index> coord, Index s) {
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to make a 1-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<1> make_Coord(int _0) {
  int values[1] = {_0};
  return Coord<1>(values);
Coord<1, T> make_Coord(T _0) {
  T values[1] = {_0};
  return Coord<1, T>(values);
}

/// Helper to make a 2-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<2> make_Coord(int _0, int _1) {
  int values[2] = {_0, _1};
  return Coord<2>(values);
Coord<2, T> make_Coord(T _0, T _1) {
  T values[2] = {_0, _1};
  return Coord<2, T>(values);
}

/// Helper to make a 3-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<3> make_Coord(int _0, int _1, int _2) {
  int values[3] = {_0, _1, _2};
  return Coord<3>(values);
Coord<3, T> make_Coord(T _0, T _1, T _2) {
  T values[3] = {_0, _1, _2};
  return Coord<3, T>(values);
}

/// Helper to make a 4-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
  int values[4] = {_0, _1, _2, _3};
  return Coord<4>(values);
Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) {
  T values[4] = {_0, _1, _2, _3};
  return Coord<4, T>(values);
}

/// Helper to make a 5-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<5> make_Coord(int _0, int _1, int _2, int _3, int _4) {
  int values[5] = {_0, _1, _2, _3, _4};
  return Coord<5>(values);
Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) {
  T values[5] = {_0, _1, _2, _3, _4};
  return Coord<5, T>(values);
}

/// Helper to make an N-element coordinate whose first element is _0 and whose
/// remaining elements are zero-padded
template <int N, typename T>
CUTLASS_HOST_DEVICE
Coord<N, T> make_Coord_with_padding(T _0) {
  Coord<N, T> coord;

  CUTLASS_PRAGMA_UNROLL
  for (int i = N - 1; i > 0; --i) {
    coord[i] = 0;
  }

  coord[0] = _0;

  return coord;
}
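
// Editorial usage sketch -- not part of the diff. The templated helpers deduce the
// index type from their arguments, which is what enables the 64-bit tensor strides and
// leading dimensions added in this release. Assumes cutlass/coord.h is included.
#include <cstdint>

inline void make_coord_examples() {
  cutlass::Coord<2> a = cutlass::make_Coord(128, 256);                   // T deduced as int
  cutlass::Coord<2, int64_t> b =
      cutlass::make_Coord(int64_t(1) << 33, int64_t(256));               // T deduced as int64_t
  cutlass::Coord<4, int64_t> c =
      cutlass::make_Coord_with_padding<4>(int64_t(42));                  // {42, 0, 0, 0}
  (void)a; (void)b; (void)c;
}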

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

@@ -34,6 +34,8 @@
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix.h"
#include "cutlass/quaternion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/tensor_view.h"
@@ -150,6 +152,45 @@ std::ostream & operator<<(std::ostream &out, MatrixShape<Row, Column> const &mat
  return out;
}


/// Prints matrix to ostream
template <typename Element, int Rows, int Columns>
std::ostream & operator<<(std::ostream &out, Matrix<Element, Rows, Columns> const &rhs) {

  for (int i = 0; i < Rows; ++i) {
    for (int j = 0; j < Columns; ++j) {
      ScalarIO<Element> element(rhs.at(i, j));
      out << (j ? ", " : "") << element;
    }
    out << "\n";
  }

  return out;
}

template <typename T>
std::ostream &operator<<(std::ostream &out, Quaternion<T> const &rhs) {

  out << ScalarIO<T>(rhs.w()) << " ";
  if (rhs.x() >= 0) {
    out << "+";
  }

  out << ScalarIO<T>(rhs.x()) << "*i ";
  if (rhs.y() >= 0) {
    out << "+";
  }

  out << ScalarIO<T>(rhs.y()) << "*j ";
  if (rhs.z() >= 0) {
    out << "+";
  }

  out << ScalarIO<T>(rhs.z()) << "*k";

  return out;
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// stream operators for cutlass::gemm namespace //
///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -141,26 +141,6 @@ static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2;
static const int NUM_THREADS_PER_QUAD = 4;
static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2;

#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))

/// Computes laneId within a warp
CUTLASS_DEVICE
int LaneId() {
  int ret;
  asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : );
  return ret;
}

/// Computes SM number the thread is running on
CUTLASS_DEVICE
int SmId() {
  int ret;
  asm ("mov.u32 %0, %%smid;" : "=r"(ret) : );
  return ret;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

@@ -180,6 +180,7 @@ struct GELU<Array<T, N> > {
// GELU operator implemented using the Taylor series approximation
template <typename T>
struct GELU_taylor {
  static const bool kIsHeavy = true;
  CUTLASS_HOST_DEVICE
  T operator()(T const &z) const {

@@ -193,6 +194,7 @@ struct GELU_taylor {

template <typename T, int N>
struct GELU_taylor<Array<T, N> > {
  static const bool kIsHeavy = true;
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
    Array<T, N> y;
@@ -250,4 +252,3 @@ struct dGELU<Array<T, N> > {
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -65,6 +65,8 @@ public:

  static FloatRoundStyle const kRound = Round;

  static bool const kIsHeavy = false;

  /// Host-constructable parameters structure
  struct Params {


@@ -49,7 +49,9 @@ namespace thread {
///
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
  int Count,                                      ///< Number of elements computed per operation.
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there is not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
@@ -146,6 +148,8 @@ public:

    if (Scale == ScaleType::OnlyAlphaScaling) return false;

    if (Scale == ScaleType::Nothing) return false;

    return beta_ != ElementCompute(0);
  }

@@ -167,11 +171,17 @@ public:
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    ComputeFragment converted_source = source_converter(source);
    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    if (Scale == ScaleType::Nothing)
      return destination_converter(converted_accumulator);

    ComputeFragment converted_source = source_converter(source);

    // Perform binary operations
    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
@@ -180,13 +190,10 @@ public:
    if (Scale == ScaleType::NoBetaScaling)
      intermediate = converted_source;
    else
      intermediate = mul_add_source(beta_, converted_source);                         // X = beta * C + uniform

    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }

@@ -198,17 +205,20 @@ public:
    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    if (Scale == ScaleType::Nothing)
      return destination_converter(converted_accumulator);

    // Perform binary operations
    ComputeFragment intermediate;
    multiplies<ComputeFragment> mul_accumulator;

    intermediate = mul_accumulator(alpha_, converted_accumulator);                    // D = alpha * Accum

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }
};
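
// Editorial scalar sketch -- not part of the diff. Per element, the epilogue above
// computes D = alpha * Accum + beta * C, and the reordering lets the ScaleType cases
// exit before the source tensor is converted (or even read). The fragment code is this
// same computation vectorized over kCount elements. ScaleKind is a hypothetical
// stand-in for the ScaleType template parameter:
enum class ScaleKind { Nothing, NoBetaScaling, Default };

template <typename Compute>
Compute linear_combination(ScaleKind scale_kind, Compute alpha, Compute beta,
                           Compute accum, Compute source) {
  if (scale_kind == ScaleKind::Nothing)       return accum;                  // D = Accum
  if (scale_kind == ScaleKind::NoBetaScaling) return alpha * accum + source; // D = alpha * Accum + C
  return alpha * accum + beta * source;                                      // D = alpha * Accum + beta * C
}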

@@ -0,0 +1,251 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Functor performing linear combination operations used by epilogues.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/epilogue/thread/activation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  typename ElementT_,
  int ElementsPerAccess,
  typename ElementwiseOp_ = Identity<ElementCompute_>,
  typename BinaryOp_ = plus<ElementCompute_>
>
class LinearCombinationBiasElementwise {
public:

  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementZ = ElementZ_;
  using ElementT = ElementT_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  using ElementwiseOp = ElementwiseOp_;
  using BinaryOp = BinaryOp_;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  using FragmentOutput = FragmentZ;

  static bool const kIsHeavy = ElementwiseOp::kIsHeavy;

  /// If true, the 'Z' tensor is stored
  static bool const kStoreZ = true;

  /// If true, the 'T' tensor is stored
  static bool const kStoreT = true;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;               ///< scales accumulators
    ElementCompute beta;                ///< scales source tensor
    ElementCompute const *alpha_ptr;    ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;     ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha
    ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  bool skip_elementwise_;

public:

  //
  // Methods
  //

  /// Constructor from Params
  CUTLASS_HOST_DEVICE
  LinearCombinationBiasElementwise(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation when is_source_needed() is true
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C,
    FragmentCompute const &V) const {

    ElementwiseOp elementwise_op;
    BinaryOp binary_op;

    FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
    FragmentCompute result_Z;
    FragmentCompute result_T;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
      result_Z[i] = z;
      result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
    frag_T = convert_t(result_T);
  }

  /// Applies the operation when is_source_needed() is false
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentCompute const &V) const {

    ElementwiseOp elementwise_op;
    BinaryOp binary_op;

    FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute result_Z;
    FragmentCompute result_T;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
      result_Z[i] = z;
      result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
    frag_T = convert_t(result_T);
  }
};
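
// Editorial note -- not part of the diff. Per element, the two operator() overloads
// above compute, with V the broadcast bias vector:
//
//   Z = BinaryOp(alpha * AB + beta * C, V)   (the beta * C term is dropped when no
//                                             source is needed)
//   T = ElementwiseOp(Z)                     (identity while skip_elementwise_ is set
//                                             by set_k_partition for non-final slices)
//
// With the defaults (BinaryOp = plus, ElementwiseOp = Identity) and alpha = 1, beta = 0,
// this reduces to the familiar bias-add Z = AB + V.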

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -28,6 +28,8 @@

#pragma once

#include <cuda_fp16.h>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
@@ -41,6 +43,146 @@ namespace cutlass {
namespace epilogue {
namespace thread {


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

template <typename Element, int ElementsPerAccess>
struct ArrayMaximum {

  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
    Array<Element, ElementsPerAccess> const &lhs,
    Array<Element, ElementsPerAccess> const &rhs) const {

    Array<Element, ElementsPerAccess> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      result[i] = fmax(lhs[i], rhs[i]);
    }

    return result;
  }
};

template <int ElementsPerAccess>
struct ArrayMaximum<half_t, ElementsPerAccess> {

  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
    Array<half_t, ElementsPerAccess> const &lhs,
    Array<half_t, ElementsPerAccess> const &rhs) const {

    Array<half_t, ElementsPerAccess> result;

#if __CUDA_ARCH__ >= 800
    int const kVectorCount = ElementsPerAccess / 2;


    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data());
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(rhs.raw_data());
    __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
    }

#else
    __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
    __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data());
    __half *res_ptr = reinterpret_cast<__half *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]);
    }

#endif

    return result;
  }

  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
    Array<half_t, ElementsPerAccess> const &lhs,
    half_t const &rhs) const {

    Array<half_t, ElementsPerAccess> result;

#if __CUDA_ARCH__ >= 800
    int const kVectorCount = ElementsPerAccess / 2;


    __half rhs_raw = reinterpret_cast<__half const &>(rhs);
    __half2 rhs_pair = __half2half2(rhs_raw);

    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data());
    __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
    }

#else

    __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data());
    __half const rhs_raw = reinterpret_cast<__half const &>(rhs);
    __half *res_ptr = reinterpret_cast<__half *>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]);
    }

#endif

    return result;
  }
};
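
// Editorial sketch -- not part of the diff. On SM80 the specialization above clamps two
// half-precision lanes per __hmax2 instruction, halving the instruction count versus
// the generic fmax loop and keeping the data in packed registers. A minimal standalone
// use of the same intrinsic (assumes CUDA 11+, compiled for sm_80 or newer):
#include <cuda_fp16.h>

__global__ void relu_half2(__half2 *x, __half2 threshold, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] = __hmax2(x[i], threshold);  // clamps both lanes at once
  }
}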

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element, int ElementsPerAccess>
struct ReluConditional {

  CUTLASS_HOST_DEVICE
  void operator()(
    bool conditional[],
    Array<Element, ElementsPerAccess> const &fragment,
    Element threshold) const {

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      conditional[i] = !(fragment[i] < threshold);
    }
  }
};

template <int ElementsPerAccess>
struct ReluConditional<half_t, ElementsPerAccess> {

  CUTLASS_DEVICE
  void operator()(
    bool conditional[],
    Array<half_t, ElementsPerAccess> const &fragment,
    half_t threshold) const {

    __half y = reinterpret_cast<__half const &>(threshold);
    __half const *x = reinterpret_cast<__half const *>(fragment.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      conditional[i] = !__hlt(x[i], y);
    }
  }
};

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This is a partial specialization for fused Bias and ReLU. It supports the option of packing
@@ -94,8 +236,11 @@ public:
    ElementCompute beta;                ///< scales source tensor
    ElementCompute const *alpha_ptr;    ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;     ///< pointer to source scalar - if not null, loads it from memory
    ElementCompute threshold;           ///< ReLu threshold
    ElementZ threshold;                 ///< ReLu threshold

    //
    // Methods
    //
    //
    // Methods
    //
@@ -112,16 +257,19 @@ public:
    Params(
      ElementCompute alpha,
      ElementCompute beta,
      ElementCompute threshold = ElementCompute()
      ElementCompute threshold_ = ElementCompute()
    ):
      alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(threshold) {
      alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

      NumericConverter<ElementZ, ElementCompute> convert_threshold;

      threshold = convert_threshold(threshold_);
    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha
    ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(threshold) {
    ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(ElementZ()) {

    }

@@ -129,17 +277,20 @@ public:
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr,
      ElementCompute threshold = ElementCompute()
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold(threshold) {
      ElementCompute threshold_ = ElementCompute()
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

      NumericConverter<ElementZ, ElementCompute> convert_threshold;

      threshold = convert_threshold(threshold_);
    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(threshold) {

    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(ElementZ()) {
    }

  };

private:
@@ -150,7 +301,7 @@ private:

  ElementCompute alpha_;
  ElementCompute beta_;
  ElementCompute threshold_;
  ElementZ threshold_;

public:

@@ -179,6 +330,12 @@ public:
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementZ const &>(allones);
    }
  }
|
||||
|
||||
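The all-ones trick above relies on IEEE semantics: an all-ones bit pattern decodes to a NaN, every ordered comparison against NaN is false, and `fmax` returns its non-NaN operand, so on non-final split-K slices both the predicate and the clamp become pass-throughs. A standalone sketch in plain C++ (helper names are hypothetical):

#include <cmath>
#include <cstdint>
#include <cstring>

inline float nan_threshold() {
  uint32_t allones = 0xFFFFFFFFu;
  float t;
  std::memcpy(&t, &allones, sizeof(t));  // all-ones bits decode to NaN
  return t;
}

inline float relu_passthrough_demo(float z) {
  float t = nan_threshold();
  bool condition = !(z < t);        // always true: (z < NaN) is false
  float clamped = std::fmax(z, t);  // fmax returns the non-NaN operand, i.e. z
  return condition ? clamped : 0.0f;  // yields z unchanged
}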
  /// Applies the operation when is_source_needed() is true
@@ -201,18 +358,27 @@ public:

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);

      bool condition = !(z < threshold_);
      z = fmax(z, threshold_);
      ElementCompute z = alpha_ * tmp_Accum[i];
      z += beta_ * tmp_C[i];

      z = binary_op(z, V[i]);
      result_Z[i] = z;
      conditions[i] = condition;
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    //
    // Compute condition
    //

    detail::ReluConditional<ElementZ, kElementsPerAccess> relu_conditional;
    relu_conditional(conditions, frag_Z, threshold_);

    detail::ArrayMaximum<ElementZ, kElementsPerAccess> maximum_op;
    frag_Z = maximum_op(frag_Z, threshold_);

    if (kStoreT) {
      PackPredicates<kElementsPerAccess> pack_predicates;
      frag_T = pack_predicates(conditions);
@@ -238,17 +404,29 @@ public:
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);

      bool condition = !(z < threshold_);
      z = fmax(z, threshold_);

      result_Z[i] = z;
      conditions[i] = condition;
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);

    //
    // Compute condition
    //

    detail::ReluConditional<ElementZ, kElementsPerAccess> relu_conditional;
    relu_conditional(conditions, frag_Z, threshold_);

    detail::ArrayMaximum<ElementZ, kElementsPerAccess> maximum_op;
    frag_Z = maximum_op(frag_Z, threshold_);

    //
    // Compute conditions
    //

    //
    // Store
    //
    if (kStoreT) {
      PackPredicates<kElementsPerAccess> pack_predicates;
      frag_T = pack_predicates(conditions);
@@ -43,6 +43,17 @@ namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Single source of truth for whether to unroll for `LinearCombinationClamp()`
constexpr bool LinearCombinationClampIsHeavy() {
  return false;
}

}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
///
@@ -51,6 +62,8 @@ namespace thread {
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
@@ -71,6 +84,8 @@ public:

  static FloatRoundStyle const kRound = Round;

  static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy();

  /// Host-constructable parameters structure
  struct Params {

@@ -282,6 +297,8 @@ public:

  static FloatRoundStyle const kRound = Round;

  static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy();

  /// Host-constructable parameters structure
  struct Params {

@@ -396,10 +413,9 @@ public:
    // Convert floats back to INT
    FragmentAccumulator scaled_accumulator;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      scaled_accumulator[i] = __float2int_rn(intermediate[i]);
    }
    NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

    scaled_accumulator = compute_converter(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
@@ -427,10 +443,9 @@ public:
    // Convert floats back to INT
    FragmentAccumulator scaled_accumulator;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      scaled_accumulator[i] = __float2int_rn(intermediate[i]);
    }
    NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

    scaled_accumulator = compute_converter(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
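Both hunks above replace a hand-written `__float2int_rn` loop with `NumericArrayConverter`, which converts a whole fragment at once and takes the rounding mode as a template argument instead of hard-coding it. A minimal standalone sketch of the pattern (the fragment width of 8 is an assumption for illustration):

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

// Hypothetical standalone illustration of the converter pattern used above.
CUTLASS_HOST_DEVICE
cutlass::Array<int, 8> round_fragment(cutlass::Array<float, 8> const &intermediate) {
  // One converter object handles the whole fragment; the rounding mode is a
  // template parameter rather than being baked in via __float2int_rn.
  cutlass::NumericArrayConverter<
      int, float, 8, cutlass::FloatRoundStyle::round_to_nearest> compute_converter;

  return compute_converter(intermediate);
}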
@@ -487,6 +502,8 @@ class FastLinearCombinationClamp {

  static FloatRoundStyle const kRound = Round;

  static bool const kIsHeavy = false;

  /// Host-constructable parameters structure
  struct Params {
    /// scales accumulators

include/cutlass/epilogue/thread/linear_combination_dgelu.h (new file, 244 lines)
@@ -0,0 +1,244 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file

  \brief Functor performing linear combination followed by the dGELU operation
*/

#pragma once

#include <cutlass/half.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/constants.h"
#include "cutlass/fast_math.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
  typename ElementCompute_,      ///< Data type returned by this functor
  typename ElementAccumulator_,  ///< Data type of accumulators
  typename ElementSource_,       ///< Data type of source tensor
  typename ElementTensor_,       ///< Data type of additional tensor
  int Count,                     ///< Number of elements computed per operation
                                 ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                 ///< but we use 64 or 32 sometimes when there are not enough data to store
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationDGelu {
public:

  using ElementOutput = ElementSource_;
  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementSource = ElementSource_;
  using ElementTensor = ElementTensor_;

  static bool const kIsHeavy = true;

  static int const kCount = Count;

  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentTensor = Array<ElementTensor, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;             ///< scales accumulators
    ElementCompute beta;              ///< scales source tensor
    ElementCompute threshold;         ///< minimum value that is output
    ElementCompute const *alpha_ptr;  ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;   ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      threshold(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  ElementCompute threshold_;
  bool participates_in_reduction_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationDGelu(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = params.threshold;
    participates_in_reduction_ = true;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Returns true if the threadblock computes the reduction
  CUTLASS_HOST_DEVICE
  bool participates_in_reduction() const {
    return participates_in_reduction_;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementCompute const &>(allones);
      // Avoid computing the reduction if this isn't the final Split-K slice
      participates_in_reduction_ = false;
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentSource const &source,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

    dGELU<ElementCompute> gelu_op;

    // dGelu
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i]));
    }

    return intermediate;
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;

    intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

    dGELU<ElementCompute> gelu_op;

    // dGelu with conversion
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i]));
    }

    return intermediate;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
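A hypothetical instantiation of the new functor, with element types and fragment width chosen for illustration rather than taken from the diff:

// Hypothetical usage sketch, not part of the file above.
using DGeluOp = cutlass::epilogue::thread::LinearCombinationDGelu<
    float,            // ElementCompute: type returned by the functor
    float,            // ElementAccumulator
    cutlass::half_t,  // ElementSource
    cutlass::half_t,  // ElementTensor: the extra tensor fed to dGELU
    8>;               // Count: elements per operation

DGeluOp::Params params(1.0f /*alpha*/, 0.0f /*beta*/);
DGeluOp op(params);
// Inside an epilogue, op(accum_frag, tensor_frag) scales the accumulators by
// alpha and applies dGELU elementwise against the extra tensor fragment.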
include/cutlass/epilogue/thread/linear_combination_drelu.h (new file, 446 lines)
@@ -0,0 +1,446 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Functor performing linear combination with a maximum operation used by epilogues.
*/

#pragma once

#include <cutlass/half.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
  typename ElementCompute_,      ///< Data type returned by this functor
  typename ElementAccumulator_,  ///< Data type of accumulators
  typename ElementSource_,       ///< Data type of source tensor
  typename ElementTensor_,       ///< Data type of additional tensor
  int Count,                     ///< Number of elements computed per operation
                                 ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                 ///< but we use 64 or 32 sometimes when there are not enough data to store
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationDRelu {
public:

  using ElementOutput = ElementSource_;
  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementSource = ElementSource_;
  using ElementTensor = ElementTensor_;

  static int const kCount = Count;

  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentTensor = Array<ElementTensor, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;             ///< scales accumulators
    ElementCompute beta;              ///< scales source tensor
    ElementCompute threshold;         ///< minimum value that is output
    ElementCompute const *alpha_ptr;  ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;   ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      threshold(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  ElementTensor threshold_;
  bool participates_in_reduction_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationDRelu(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = ElementTensor(params.threshold);
    participates_in_reduction_ = true;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Returns true if the threadblock computes the reduction
  CUTLASS_HOST_DEVICE
  bool participates_in_reduction() const {
    return participates_in_reduction_;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementTensor const &>(allones);
      participates_in_reduction_ = false;
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentSource const &source,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

    // dReLU = (cond ? dy : 0)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      ElementTensor cond = tensor[i];
      if (cond <= threshold_) {
        intermediate[i] = ElementCompute();
      }
    }

    return intermediate;
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;

    intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

    // dReLU = (cond ? dy : 0)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      ElementTensor cond = tensor[i];
      if (cond <= threshold_) {
        intermediate[i] = ElementCompute();
      }
    }

    return intermediate;
  }
};
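As a scalar reference for the two fragment loops above (not part of the file): `tensor` carries the forward-pass conditional, and the backward pass zeroes the incoming gradient wherever the conditional fails the threshold test.

inline float drelu_reference(float dy, float cond, float threshold) {
  // dReLU: pass dy through where the forward activation was above threshold.
  return (cond <= threshold) ? 0.0f : dy;
}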
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
  typename ElementCompute_,      ///< Data type returned by this functor
  typename ElementAccumulator_,  ///< Data type of accumulators
  typename ElementSource_,       ///< Data type of source tensor
  int Count,                     ///< Number of elements computed per operation
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationDReluConditionalBits {
public:

  using ElementOutput = ElementSource_;
  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementSource = ElementSource_;
  using ElementTensor = uint1b_t;

  static bool const kIsHeavy = false;

  static int const kCount = Count;

  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentTensor = Array<ElementTensor, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;             ///< scales accumulators
    ElementCompute beta;              ///< scales source tensor
    ElementCompute const *alpha_ptr;  ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;   ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  FragmentTensor predicate_mask_;
  bool participates_in_reduction_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationDReluConditionalBits(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    participates_in_reduction_ = true;
    predicate_mask_.clear();
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Returns true if the threadblock computes the reduction
  CUTLASS_HOST_DEVICE
  bool participates_in_reduction() const {
    return participates_in_reduction_;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    predicate_mask_.clear();

    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // Avoid computing the reduction if this isn't the final Split-K slice
      participates_in_reduction_ = false;

      bit_not<FragmentTensor> not_op;
      predicate_mask_ = not_op(predicate_mask_);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentSource const &source,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

    bit_or<FragmentTensor> or_op;

    FragmentTensor predicates = or_op(tensor, predicate_mask_);

    // Obtain from packed bits
    bool conditions[kCount];
    UnpackPredicates<kCount> unpack_predicates;

    unpack_predicates(conditions, predicates);

    // dReLU = (cond ? dy : 0)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      if (!conditions[i]) {
        intermediate[i] = ElementCompute();
      }
    }

    return intermediate;
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;

    intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

    bit_or<FragmentTensor> or_op;

    FragmentTensor predicates = or_op(tensor, predicate_mask_);

    // Obtain from packed bits
    bool conditions[kCount];
    UnpackPredicates<kCount> unpack_predicates;

    unpack_predicates(conditions, predicates);

    // dReLU = (cond ? dy : 0)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      if (!conditions[i]) {
        intermediate[i] = ElementCompute();
      }
    }

    return intermediate;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
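The `ConditionalBits` variant stores one predicate per bit (`uint1b_t`); OR-ing the packed word with `predicate_mask_`, which is all-ones on non-final split-K slices, forces every predicate true so those slices pass gradients through unmodified. A simplified sketch over a single 32-bit word (a hypothetical helper, not the library's API):

#include <cstdint>

// mask == 0 normally; mask == ~0u on non-final split-K slices.
inline float drelu_bit(float dy, uint32_t packed, int i, uint32_t mask) {
  uint32_t predicates = packed | mask;      // all-ones mask forces every bit set
  bool condition = (predicates >> i) & 1u;  // unpack one predicate
  return condition ? dy : 0.0f;             // dReLU = (cond ? dy : 0)
}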
@@ -51,6 +51,8 @@ namespace thread {
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
@@ -62,6 +64,8 @@ public:
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static bool const kIsHeavy = true;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
@@ -134,10 +138,11 @@ public:
  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    CUTLASS_UNUSED(k_partition_count);
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    CUTLASS_UNUSED(k_partition_count);
  }

  /// Computes: D = gelu( alpha * accumulator + beta * source )

@@ -52,6 +52,8 @@ namespace thread {
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest

@@ -45,6 +45,17 @@ namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Single source of truth for whether to unroll for `LinearCombinationRelu()`
constexpr bool LinearCombinationReluIsHeavy() {
  return false;
}

}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
@@ -52,6 +63,8 @@ namespace thread {
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  ScaleType::Kind Scale = ScaleType::Default,     ///< Control Alpha and Beta scaling
@@ -72,6 +85,8 @@ public:

  static FloatRoundStyle const kRound = Round;

  static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy();

  /// Host-constructable parameters structure
  struct Params {

@@ -244,6 +259,8 @@ public:
  using ElementAccumulator = int;
  using ElementCompute = float;

  static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy();

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
@@ -357,10 +374,10 @@ public:
    ReLu<ComputeFragment> relu;

    if (Scale == ScaleType::NoBetaScaling)
      intermediate = converted_source;
    intermediate = converted_source;
    else
      intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform

    intermediate = mul_add_source(beta_, converted_source);   // X = beta * C + uniform

    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

    // Compute threshold optionally
@@ -378,10 +395,9 @@ public:
    // Convert floats back to INT
    FragmentAccumulator scaled_accumulator;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      scaled_accumulator[i] = __float2int_rn(intermediate[i]);
    }
    NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

    scaled_accumulator = compute_converter(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, int, kCount, Round>
@@ -416,14 +432,6 @@ public:
    // Compute threshold optionally
    intermediate = relu(threshold_, intermediate);

    // Convert floats back to INT
    FragmentAccumulator scaled_accumulator;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      scaled_accumulator[i] = __float2int_rn(intermediate[i]);
    }

    if (platform::is_same<ElementOutput, int32_t>::value ||
        platform::is_same<ElementOutput, uint32_t>::value ||
        platform::is_same<ElementOutput, int16_t>::value ||
@@ -436,10 +444,9 @@ public:
    // Convert floats back to INT
    FragmentAccumulator scaled_accumulator;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      scaled_accumulator[i] = __float2int_rn(intermediate[i]);
    }
    NumericArrayConverter<int, ElementCompute, kCount, Round> compute_converter;

    scaled_accumulator = compute_converter(intermediate);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, int, kCount, Round>

@@ -51,6 +51,8 @@ namespace thread {
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
                                                  ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                                  ///< but we use 64 or 32 sometimes when there are not enough data to store
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
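The `kIsHeavy` constant added across these functors is a compile-time hint the threadblock epilogue can use to trade unrolling against code size for complicated elementwise operations. A hypothetical sketch of how such a flag can gate unrolling (the epilogue's per-iteration body is elided, and this function is an illustration, not the library's implementation):

#include "cutlass/cutlass.h"

// OutputOp is any functor exposing a compile-time kIsHeavy constant.
template <typename OutputOp, int kIterations>
CUTLASS_DEVICE
void run_epilogue_iterations(OutputOp const &op) {
  if (OutputOp::kIsHeavy) {
    // Plain loop: smaller code footprint for heavy elementwise operations.
    for (int i = 0; i < kIterations; ++i) {
      // ... one epilogue iteration using op ...
    }
  }
  else {
    // Fully unrolled loop: maximum throughput for lightweight operations.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations; ++i) {
      // ... one epilogue iteration using op ...
    }
  }
}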
@@ -0,0 +1,228 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file

  \brief Functor performing linear combination with an elementwise operator
*/

#pragma once

#include <cutlass/half.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/constants.h"
#include "cutlass/fast_math.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
  typename ElementCompute_,      ///< Data type returned by this functor
  typename ElementAccumulator_,  ///< Data type of accumulators
  typename ElementSource_,       ///< Data type of source tensor
  typename ElementTensor_,       ///< Data type of additional tensor
  int Count,                     ///< Number of elements computed per operation
                                 ///< Usually it is 128/sizeof_bits<ElementOutput_>,
                                 ///< but we use 64 or 32 sometimes when there are not enough data to store
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationWithElementwise {
public:

  using ElementOutput = ElementSource_;
  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementSource = ElementSource_;
  using ElementTensor = ElementTensor_;

  static bool const kIsHeavy = true;

  static int const kCount = Count;

  using FragmentCompute = Array<ElementCompute, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentTensor = Array<ElementTensor, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;             ///< scales accumulators
    ElementCompute beta;              ///< scales source tensor
    ElementCompute threshold;         ///< minimum value that is output
    ElementCompute const *alpha_ptr;  ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;   ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      threshold(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr,
      ElementCompute threshold = ElementCompute(0)
    ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;
  ElementCompute threshold_;
  bool participates_in_reduction_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationWithElementwise(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    threshold_ = params.threshold;
    participates_in_reduction_ = true;
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Returns true if the threadblock computes the reduction
  CUTLASS_HOST_DEVICE
  bool participates_in_reduction() const {
    return participates_in_reduction_;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      // set to NaN to make ReLU no-op for all except last k partitions
      int64_t allones = -1;
      threshold_ = reinterpret_cast<ElementCompute const &>(allones);
      // Avoid computing the reduction if this isn't the final Split-K slice
      participates_in_reduction_ = false;
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentSource const &source,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_source = source_converter(source);
    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_add_source;
    multiply_add<FragmentCompute> mul_add_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

    return intermediate;
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentCompute operator()(
    FragmentAccumulator const &accumulator,
    FragmentTensor const &tensor) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    FragmentCompute converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    FragmentCompute intermediate;

    multiplies<FragmentCompute> mul_accumulator;

    intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

    return intermediate;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
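A hypothetical instantiation of `LinearCombinationWithElementwise`; note that the functor returns the linear combination in compute precision and leaves the elementwise operation to the surrounding epilogue, which is why both `operator()` overloads ignore the `tensor` fragment:

// Hypothetical type alias; element types chosen for illustration.
using ElementwiseBase = cutlass::epilogue::thread::LinearCombinationWithElementwise<
    float,            // ElementCompute
    float,            // ElementAccumulator
    cutlass::half_t,  // ElementSource
    cutlass::half_t,  // ElementTensor
    8>;               // Count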
@@ -41,9 +41,10 @@ namespace thread {
/// Specifies internal data type for computation
struct ScaleType {
  enum Kind {
    Default,          // alpha x C + beta x D
    NoBetaScaling,    // alpha x C + D
    OnlyAlphaScaling  // alpha x C
    Default,          // alpha x C + beta x D
    NoBetaScaling,    // alpha x C + D
    OnlyAlphaScaling, // alpha x C
    Nothing           // C
  };
};
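A hypothetical use of the new `Nothing` kind as the `Scale` template argument of one of the functors above, which disables the alpha and beta scaling entirely (whether a particular functor honors every kind is an assumption here, not stated in the diff):

// Hypothetical: pass accumulators through unscaled before the activation.
using PassThroughRelu = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t,  // ElementOutput
    8,                // Count
    float,            // ElementAccumulator
    float,            // ElementCompute
    cutlass::epilogue::thread::ScaleType::Nothing>;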
@@ -44,7 +44,6 @@
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"

#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

@@ -55,6 +54,8 @@
#include "cutlass/epilogue/threadblock/default_thread_map_simt.h"

#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue.h"

@@ -144,6 +145,164 @@ struct DefaultEpilogueSimt {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for SimtOps.
template <
  typename Shape_,
  typename WarpMmaSimt_,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueSimtStridedDgrad {

  using Shape = Shape_;
  using WarpMmaSimt = WarpMmaSimt_;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaSimt::LayoutC;
  using ElementAccumulator = typename WarpMmaSimt::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
    Shape,
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::Policy,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
    OutputTileThreadMap,
    ElementOutput
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::ThreadMma,
    layout::RowMajor,
    typename WarpMmaSimt::Policy
  >;

  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::ThreadMma,
    ElementAccumulator,
    layout::RowMajor,
    typename WarpMmaSimt::Policy
  >;

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator
  >;

  /// Hard-coded padding elements added
  using Padding = typename WarpTileIterator::Padding;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaSimt,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for SimtOps.
template <
  int Rank,
  typename Shape_,
  typename WarpMmaSimt_,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueSimtAffineRankN {

  using Shape = Shape_;
  using WarpMmaSimt = WarpMmaSimt_;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaSimt::LayoutC;
  using ElementAccumulator = typename WarpMmaSimt::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
    Shape,
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::Policy,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
    OutputTileThreadMap,
    ElementOutput,
    Rank
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::ThreadMma,
    layout::RowMajor,
    typename WarpMmaSimt::Policy
  >;

  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::ThreadMma,
    ElementAccumulator,
    layout::RowMajor,
    typename WarpMmaSimt::Policy
  >;

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator
  >;

  /// Hard-coded padding elements added
  using Padding = typename WarpTileIterator::Padding;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaSimt,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};
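A hypothetical use of the new affine rank-N default, where `Shape`, `WarpMmaSimt`, and `OutputOp` stand for types a concrete kernel would already define:

// Hypothetical alias; the template parameters are placeholders.
template <typename Shape, typename WarpMmaSimt, typename OutputOp>
using AffineEpilogue = typename cutlass::epilogue::threadblock::
    DefaultEpilogueSimtAffineRankN<
        2,            // Rank: affine rank-2 layouts (row and column strides)
        Shape,        // threadblock tile shape
        WarpMmaSimt,  // warp-level SIMT MMA
        OutputOp,     // e.g. a LinearCombination functor
        1             // ElementsPerAccess
    >::Epilogue;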
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
@ -56,6 +56,8 @@

#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"

@ -364,6 +366,188 @@ struct DefaultEpilogueTensorOp {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueTensorOpStridedDgrad {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
    OutputTileThreadMap,
    ElementOutput
  >;

  using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
    cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
      typename WarpMmaTensorOp::Shape,
      typename WarpMmaTensorOp::Policy::Operator::Shape,
      typename WarpMmaTensorOp::Policy::Operator::ElementC,
      typename WarpMmaTensorOp::Policy::Operator::FragmentC,
      LayoutC>,
    cutlass::epilogue::warp::FragmentIteratorTensorOp<
      typename WarpMmaTensorOp::Shape,
      typename WarpMmaTensorOp::Policy::Operator::Shape,
      typename WarpMmaTensorOp::Policy::Operator::ElementC,
      typename WarpMmaTensorOp::Policy::Operator::FragmentC,
      LayoutC> >::type;

  /// Support several implementations depending on structure of epilogue
  using DefaultIterators = detail::DefaultIteratorsTensorOp<
    ElementOutput,
    ElementAccumulator,
    kElementsPerAccess,
    Shape,
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename OutputTileThreadMap::CompactedThreadMap
  >;

  using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
  using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;

  /// Hard-coded padding elements added
  using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;

  static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding,
    kFragmentsPerIteration
  >;
};


////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  int Rank,
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueTensorOpAffineRankN {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
    OutputTileThreadMap,
    ElementOutput,
    Rank
  >;

  // Map to the row major iterator since the iterator selection for affineN is the same.
  using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
    cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
      typename WarpMmaTensorOp::Shape,
      typename WarpMmaTensorOp::Policy::Operator::Shape,
      typename WarpMmaTensorOp::Policy::Operator::ElementC,
      typename WarpMmaTensorOp::Policy::Operator::FragmentC,
      layout::RowMajor>,
    cutlass::epilogue::warp::FragmentIteratorTensorOp<
      typename WarpMmaTensorOp::Shape,
      typename WarpMmaTensorOp::Policy::Operator::Shape,
      typename WarpMmaTensorOp::Policy::Operator::ElementC,
      typename WarpMmaTensorOp::Policy::Operator::FragmentC,
      layout::RowMajor> >::type;

  /// Support several implementations depending on structure of epilogue
  using DefaultIterators = detail::DefaultIteratorsTensorOp<
    ElementOutput,
    ElementAccumulator,
    kElementsPerAccess,
    Shape,
    typename WarpMmaTensorOp::Shape,
    typename WarpMmaTensorOp::Policy::Operator::Shape,
    typename OutputTileThreadMap::CompactedThreadMap
  >;

  using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
  using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;

  /// Hard-coded padding elements added
  using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;

  static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1);

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding,
    kFragmentsPerIteration
  >;
};

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps that use an
/// interleaved output layout. For this case, shared memory is not needed.
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,
@ -49,7 +49,9 @@

#include "cutlass/epilogue/thread/reduction_op.h"

#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"

#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h"

@ -149,6 +151,174 @@ struct DefaultEpilogueVoltaTensorOp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueVoltaTensorOpStridedDgrad {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess,
    ElementAccumulator
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad<
    OutputTileThreadMap,
    ElementOutput
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
    typename WarpMmaTensorOp::Shape,
    gemm::GemmShape<32, 32, 4>,
    ElementAccumulator,
    LayoutC
  >;

  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
    typename WarpMmaTensorOp::Shape,
    gemm::GemmShape<32, 32, 4>,
    ElementAccumulator,
    LayoutC
  >;

  static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;

  static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator,
    kSharedMemAlignment
  >;

  /// Hard-coded padding elements added
  using Padding = typename WarpTileIterator::Padding;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  int Rank,
  typename Shape_,
  typename WarpMmaTensorOp_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpilogueVoltaTensorOpAffineRankN {

  using Shape = Shape_;
  using WarpMmaTensorOp = WarpMmaTensorOp_;
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;
  using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

  //
  // Thread map
  //

  using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp<
    Shape,
    typename WarpMmaTensorOp::Shape,
    kPartitionsK,
    ElementOutput,
    kElementsPerAccess,
    ElementAccumulator
  >::Type;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN<
    OutputTileThreadMap,
    ElementOutput,
    Rank
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
    typename WarpMmaTensorOp::Shape,
    gemm::GemmShape<32, 32, 4>,
    ElementAccumulator,
    LayoutC
  >;

  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
    typename WarpMmaTensorOp::Shape,
    gemm::GemmShape<32, 32, 4>,
    ElementAccumulator,
    LayoutC
  >;

  static int const kSharedMemAlignment = sizeof_bits<ElementAccumulator>::value * WarpTileIterator::kElementsPerAccess / 8;

  static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B");

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
    typename OutputTileThreadMap::CompactedThreadMap,
    ElementAccumulator,
    kSharedMemAlignment
  >;

  /// Hard-coded padding elements added
  using Padding = typename WarpTileIterator::Padding;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::Epilogue<
    Shape,
    WarpMmaTensorOp,
    kPartitionsK,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
@ -0,0 +1,171 @@

/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultEpilogueWithBroadcastTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcast<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding,
    Base::kFragmentsPerIteration
  >;
};
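// ---------------------------------------------------------------------------
// For intuition: the fused broadcast epilogue computes, per output tile, the
// equivalent of adding a per-column bias vector before an elementwise stage.
// The host-side reference sketch below is illustrative only (its function and
// parameter names are hypothetical, not CUTLASS device code), and assumes the
// output op simply adds the broadcast vector and applies an activation.
// ---------------------------------------------------------------------------
//
//   #include <functional>
//   #include <vector>
//
//   // z = y + V (broadcast over columns); t = elementwise(z)
//   void epilogue_with_broadcast_reference(
//       std::vector<float> &z,                   // stored via OutputTileIterator
//       std::vector<float> &t,                   // stored via TensorTileIterator
//       std::vector<float> const &accum,         // y = GEMM(A, B, C), row-major M x N
//       std::vector<float> const &broadcast,     // V, one element per column
//       std::function<float(float)> elementwise, // e.g. ReLU
//       int M, int N) {
//
//     for (int m = 0; m < M; ++m) {
//       for (int n = 0; n < N; ++n) {
//         float zi = accum[m * N + n] + broadcast[n]; // broadcast over columns
//         z[m * N + n] = zi;
//         t[m * N + n] = elementwise(zi);
//       }
//     }
//   }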

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for VoltaTensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultEpilogueWithBroadcastVoltaTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcast<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding
  >;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,161 @@

/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename OutputOp,
  typename ReductionOp,
  int ElementsPerAccess
>
struct DefaultEpilogueWithReductionTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  /// Additional tensor tile iterator
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    typename OutputOp::ElementTensor
  >;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithReduction<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    typename WarpMmaTensorOp::ElementC,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    typename Base::OutputOp,
    ReductionOp,
    typename Base::Padding
  >;
};
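// ---------------------------------------------------------------------------
// For intuition: the fused-reduction epilogue produces the elementwise output
// and, in the same pass, a partial reduction of that output accumulated with
// ReductionOp. The sketch below is a host-side illustration only; it assumes
// plus<> as the reduction and a per-column partial sum, which is one possible
// configuration, not the definitive one (the reduced dimension and operator
// are template parameters in the real code).
// ---------------------------------------------------------------------------
//
//   #include <vector>
//
//   void epilogue_with_reduction_reference(
//       std::vector<float> &d,            // elementwise output, M x N row-major
//       std::vector<float> &reduction,    // fused partial reduction, length N
//       std::vector<float> const &accum,  // GEMM accumulators, M x N row-major
//       std::vector<float> const &source, // source tensor C
//       float alpha, float beta, int M, int N) {
//
//     for (int n = 0; n < N; ++n) {
//       reduction[n] = 0.f;
//     }
//     for (int m = 0; m < M; ++m) {
//       for (int n = 0; n < N; ++n) {
//         float z = alpha * accum[m * N + n] + beta * source[m * N + n];
//         d[m * N + n] = z;
//         reduction[n] += z;  // ReductionOp assumed to be plus<>
//       }
//     }
//   }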

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename OutputOp,
  typename ReductionOp,
  int ElementsPerAccess
>
struct DefaultEpilogueWithReductionVoltaTensorOp {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  /// Additional tensor tile iterator
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    typename OutputOp::ElementTensor
  >;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithReduction<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    typename WarpMmaTensorOp::ElementC,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    typename Base::OutputOp,
    ReductionOp,
    typename Base::Padding
  >;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@ -54,6 +54,7 @@

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/util/index_sequence.h"

////////////////////////////////////////////////////////////////////////////////

@ -74,7 +75,9 @@ template <
  typename SharedLoadIterator_,  ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,            ///< Output operator
  typename Padding_,             ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1  ///< Used to coarsen the epilogue granularity
  int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
  int IterationsUnroll =         ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class Epilogue :
  public EpilogueBase<

@ -141,8 +144,8 @@ public:
  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  int const kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
  int const kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
  static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;

public:

@ -194,8 +197,52 @@

private:

  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                                      WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                              (1 - Base::kFragmentsPerIteration));
      }
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };
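// ---------------------------------------------------------------------------
// The push() above dispatches a runtime iteration index to a compile-time
// template argument using a braced-init-list pack expansion over an
// index_sequence. A standalone sketch of the same idiom (plain C++14; the
// function names are hypothetical, for illustration only):
// ---------------------------------------------------------------------------
//
//   #include <cstdio>
//   #include <utility>
//
//   template <int I>
//   void process() {
//     std::printf("process<%d>()\n", I);
//   }
//
//   template <size_t... Seq>
//   void dispatch(size_t pos, std::index_sequence<Seq...>) {
//     // For each Seq, '(pos == Seq) && (process<Seq>(), 0)' invokes
//     // process<Seq>() only when the runtime index matches; the array
//     // initializer forces the pack to expand in order.
//     int dummy[] = {(pos == Seq) && (process<Seq>(), 0)...};
//     (void)dummy;
//   }
//
//   int main() {
//     dispatch(2, std::make_index_sequence<4>{});  // prints "process<2>()"
//     return 0;
//   }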

  static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(

@ -214,7 +261,7 @@ private:
    // Iterate over accumulator tile
    //

    CUTLASS_PRAGMA_UNROLL
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

      //
@ -224,23 +271,11 @@ private:
      __syncthreads();

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        this->warp_tile_iterator_.store(accum_fragment);

        if (p < Base::kFragmentsPerIteration - 1) {
          this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
      acc2smem_source_not_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                       Base::kFragmentsPerIteration>>::push(iter,
                                                                            accum_fragment_iterator,
                                                                            this->warp_tile_iterator_);

      __syncthreads();

@ -295,7 +330,34 @@ private:
      }
    }
  }

  template<class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template<int Advance>
    CUTLASS_DEVICE
    static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                       WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(

@ -319,7 +381,7 @@ private:
    // Iterate over accumulator tile
    //

    CUTLASS_PRAGMA_UNROLL
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
@ -335,12 +397,8 @@ private:

      __syncthreads();

      typename AccumulatorFragmentIterator::Fragment accum_fragment;

      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      this->warp_tile_iterator_.store(accum_fragment);
      acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
          iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

@ -32,6 +32,9 @@

#pragma once

#include <type_traits>
#include <utility>

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else

@ -59,6 +62,32 @@ namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

//
// This is used for metaprogramming epilogue functors. If they define
// `static bool const kIsHeavy = true;`, then the epilogue functor itself is
// not inlined. This results in smaller code and is advantageous if the epilogue
// functor consists of many instructions.
//
// If the epilogue functor does not define `kIsHeavy` or if it is `false`, then
// the behavior from CUTLASS 2.5 and before is retained. The epilogue is fully
// unrolled and inlined.
//

template<class>
struct TypeSink { typedef void type; };

template<class T> using TypeSinkT = typename TypeSink<T>::type;

template<class T, class=void> struct IsEpilogueFunctorHeavy {
  static bool const value = false;
};

template<class T> struct IsEpilogueFunctorHeavy<T, TypeSinkT< decltype( T::kIsHeavy ) > > {
  static bool const value = T::kIsHeavy;
};
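// ---------------------------------------------------------------------------
// To see the detection idiom in action: a functor opts in simply by declaring
// the kIsHeavy member; everything else is deduced via the TypeSinkT SFINAE
// probe above. The functor names below are hypothetical, used only to
// exercise the trait.
// ---------------------------------------------------------------------------
//
//   struct HeavyFunctor {
//     static bool const kIsHeavy = true;  // detected via TypeSinkT<decltype(T::kIsHeavy)>
//   };
//
//   struct LightFunctor { };              // no kIsHeavy member: trait defaults to false
//
//   static_assert(IsEpilogueFunctorHeavy<HeavyFunctor>::value,
//                 "kIsHeavy member is detected and forwarded");
//   static_assert(!IsEpilogueFunctorHeavy<LightFunctor>::value,
//                 "absence of kIsHeavy defaults to false");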

////////////////////////////////////////////////////////////////////////////////

/// Base class for epilogues defining warp-level
template <
  typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h (new file, 207 lines)
@ -0,0 +1,207 @@

/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/util/index_sequence.h"

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator
template <
  typename ElementAccumulator_,
  typename ElementOutput_,
  typename ThreadBlockShape_, ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,  ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  bool ReduceKForA_
>
class EpilogueGemmKReduction {

public:

  using ThreadBlockShape = ThreadBlockShape_;
  using WarpMmaOperator = WarpMmaOperator_;
  using WarpShape = typename WarpMmaOperator::Shape;
  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// Accumulator element
  using ElementAccumulator = ElementAccumulator_;

  /// Output element
  using ElementOutput = ElementOutput_;

  /// Output access size
  static int const kElementsPerAccess = 1;

  static bool const kReduceKForA = ReduceKForA_;

  static int const kThreadBlockSize = kReduceKForA ? ThreadBlockShape::kM : ThreadBlockShape::kN;

  static int const kWarpSize = kReduceKForA ? WarpShape::kM : WarpShape::kN;

  static int const kIterations = kWarpSize / 8;

  using FragmentAccumulator = Array<ElementAccumulator, kIterations>;

private:

  int thread_offset_;
  ElementOutput* pointer_;
  int col_;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueGemmKReduction(
    int thread_idx,           ///< ID of a thread within the threadblock
    int warp_idx,             ///< ID of warp within threadblock
    int lane_idx,             ///< Id of thread within warp
    int threadblock_offset,
    ElementOutput* pointer
  )
  {
    // Decompose the lane index into its position within the warp's accumulator layout
    col_ = lane_idx % 4;
    thread_offset_ = threadblock_offset * kThreadBlockSize
                     + warp_idx * kWarpSize
                     + lane_idx / 4 + col_ * 8;

    pointer_ = pointer + LongIndex(thread_offset_);
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    int size,
    FragmentAccumulator &gemm_k_with_reduction_accumulation,
    bool LoadForSerialSplitK
  ) {
    bool guard[kIterations / 4];

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      guard[i] = ((thread_offset_ + i * 32) < size);
    }

    Array<ElementOutput, kIterations / 4> source;
    source.clear();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      ElementOutput tmp;
      cutlass::arch::global_load<ElementOutput, sizeof(ElementOutput)>(
          tmp,
          (void *)(pointer_ + i * 32),
          guard[i] && LoadForSerialSplitK);

      source[i] = tmp;
    }

    FragmentAccumulator sum = gemm_k_with_reduction_accumulation;

    // Butterfly exchange: XOR-lane shuffles by 1 and 2 accumulate partial
    // sums across each group of four lanes
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations; ++i) {
      sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 1);
      sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 2);
    }

    Array<ElementAccumulator, kIterations / 4> intermediate;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      if (col_ == 0) {
        intermediate[i] = sum[0 + i * 4];
      }

      if (col_ == 1) {
        intermediate[i] = sum[1 + i * 4];
      }

      if (col_ == 2) {
        intermediate[i] = sum[2 + i * 4];
      }

      if (col_ == 3) {
        intermediate[i] = sum[3 + i * 4];
      }
    }

    NumericArrayConverter<ElementAccumulator, ElementOutput, kIterations / 4> source_converter;
    Array<ElementAccumulator, kIterations / 4> converted_source = source_converter(source);

    plus<Array<ElementAccumulator, kIterations / 4>> plus_source;
    intermediate = plus_source(intermediate, converted_source);

    NumericArrayConverter<ElementOutput, ElementAccumulator, kIterations / 4> converter;
    Array<ElementOutput, kIterations / 4> result = converter(intermediate);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kIterations / 4; ++i) {
      cutlass::arch::global_store<ElementOutput, sizeof(ElementOutput)>(result[i],
          (void *)(pointer_ + i * 32), guard[i]);
    }
  }
};
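// ---------------------------------------------------------------------------
// The cross-lane reduction above relies on __shfl_xor_sync butterfly
// exchanges: after XOR-shuffles by 1 and 2, every lane in a group of four
// holds the group's sum. A self-contained CUDA kernel demonstrating the same
// pattern (the kernel name is illustrative, not part of CUTLASS):
// ---------------------------------------------------------------------------
//
//   #include <cstdio>
//
//   __global__ void butterfly_sum4() {
//     int lane = threadIdx.x & 31;
//     float v = static_cast<float>(lane);      // each lane contributes its lane id
//
//     v += __shfl_xor_sync(0xffffffff, v, 1);  // pairwise exchange
//     v += __shfl_xor_sync(0xffffffff, v, 2);  // combine pairs: groups of 4 done
//
//     if (lane % 4 == 0) {
//       // e.g. lanes 0-3 sum to 0+1+2+3 = 6
//       printf("lane %d group sum = %f\n", lane, v);
//     }
//   }
//
//   int main() {
//     butterfly_sum4<<<1, 32>>>();
//     cudaDeviceSynchronize();
//     return 0;
//   }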

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h (new file, 817 lines)
@ -0,0 +1,817 @@

/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include <utility>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

#include "cutlass/util/index_sequence.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  typename ElementT_,
  int ElementsPerAccess,
  bool StoreZ = true,
  bool StoreT = true
>
struct EpilogueWithBroadcastOpBase {

  using ElementOutput = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementZ = ElementZ_;
  using ElementT = ElementT_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  /// If true, the 'Z' tensor is stored
  static bool const kStoreZ = StoreZ;

  /// If true, the 'T' tensor is stored
  static bool const kStoreT = StoreT;

  /// Parameters structure - required
  struct Params { };

  //
  // Methods
  //

  /// Constructor from Params
  EpilogueWithBroadcastOpBase(Params const &params_) { }

  /// Determine if the source is needed. May return false if the source tensor C is never read.
  bool is_source_needed() const {
    return true;
  }

  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) { }

  /// Applies the operation when is_source_needed() is true
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C,
    FragmentCompute const &V) const {

  }

  /// Applies the operation when is_source_needed() is false
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentCompute const &V) const {

  }
};
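// ---------------------------------------------------------------------------
// A sketch of a concrete functor satisfying the concept above, computing
// z = AB + C + V (per-column bias add) and t = ReLU(z). The functor name and
// its arithmetic are illustrative assumptions, not part of the CUTLASS API.
// ---------------------------------------------------------------------------
//
//   template <typename Element, int ElementsPerAccess>
//   struct BiasReluBroadcastOp {
//
//     using FragmentAccumulator = Array<Element, ElementsPerAccess>;
//     using FragmentCompute = Array<Element, ElementsPerAccess>;
//     using FragmentC = Array<Element, ElementsPerAccess>;
//     using FragmentZ = Array<Element, ElementsPerAccess>;
//     using FragmentT = Array<Element, ElementsPerAccess>;
//
//     static bool const kStoreZ = true;
//     static bool const kStoreT = true;
//
//     struct Params { };
//
//     CUTLASS_HOST_DEVICE
//     BiasReluBroadcastOp(Params const &) { }
//
//     CUTLASS_HOST_DEVICE
//     bool is_source_needed() const { return true; }
//
//     CUTLASS_HOST_DEVICE
//     void set_k_partition(int, int) { }
//
//     /// z = AB + C + V, t = ReLU(z)
//     CUTLASS_HOST_DEVICE
//     void operator()(
//         FragmentZ &frag_Z,
//         FragmentT &frag_T,
//         FragmentAccumulator const &AB,
//         FragmentC const &frag_C,
//         FragmentCompute const &V) const {
//
//       CUTLASS_PRAGMA_UNROLL
//       for (int i = 0; i < ElementsPerAccess; ++i) {
//         Element z = AB[i] + frag_C[i] + V[i];
//         frag_Z[i] = z;
//         frag_T[i] = z > Element(0) ? z : Element(0);  // ReLU
//       }
//     }
//   };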
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Epilogue operator with bias vector broadcast over columns.
|
||||
///
|
||||
/// Computes the following:
|
||||
///
|
||||
///
|
||||
/// Z, T = OutputOp(AB, C, Broadcast)
|
||||
///
|
||||
/// if (ElementwiseOp::kStoreZ) {
|
||||
/// store(converted_u);
|
||||
/// }
|
||||
///
|
||||
/// if (ElementwiseOp::kStoreT) {
|
||||
/// store(v);
|
||||
/// }
|
||||
///
|
||||
template <
|
||||
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
|
||||
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
|
||||
int PartitionsK, ///< Number of partitions of the K dimension
|
||||
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z)
|
||||
typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t)
|
||||
typename ElementVector_, ///< Pointer to broadcast vector
|
||||
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
|
||||
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
|
||||
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
|
||||
typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp
|
||||
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
|
||||
int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity
|
||||
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
|
||||
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
|
||||
>
|
||||
class EpilogueWithBroadcast :
|
||||
public EpilogueBase<
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Padding_,
|
||||
FragmentsPerPartition> {
|
||||
|
||||
public:
|
||||
|
||||
using Base = EpilogueBase<
|
||||
Shape_,
|
||||
typename WarpMmaOperator_::Shape,
|
||||
PartitionsK,
|
||||
AccumulatorFragmentIterator_,
|
||||
WarpTileIterator_,
|
||||
Padding_,
|
||||
FragmentsPerPartition>;
|
||||
|
||||
using Shape = Shape_;
|
||||
using WarpMmaOperator = WarpMmaOperator_;
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using TensorTileIterator = TensorTileIterator_;
|
||||
using ElementVector = ElementVector_;
|
||||
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
|
||||
using WarpTileIterator = WarpTileIterator_;
|
||||
using SharedLoadIterator = SharedLoadIterator_;
|
||||
using OutputOp = OutputOp_;
|
||||
using Padding = Padding_;
|
||||
|
||||
using Layout = layout::RowMajor;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
/// The complete warp-level accumulator tile
|
||||
using AccumulatorTile = typename Base::AccumulatorTile;
|
||||
|
||||
/// Accumulator element
|
||||
using ElementAccumulator = typename WarpTileIterator::Element;
|
||||
|
||||
/// Compute data type produced by the output op
|
||||
using ElementCompute = typename OutputOp::ElementCompute;
|
||||
|
||||
/// Compute fragment
|
||||
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
|
||||
|
||||
/// Thread map used by output tile iterators
|
||||
using ThreadMap = typename OutputTileIterator::ThreadMap;
|
||||
|
||||
/// Fragment object used to store the broadcast values
|
||||
using BroadcastFragment = Array<
|
||||
ElementCompute,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Output element
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
|
||||
/// Data type of additional tensor
|
||||
using ElementTensor = typename TensorTileIterator::Element;
|
||||
|
||||
/// Output access size
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Tensor reference to destination tensor
|
||||
using TensorRef = typename OutputTileIterator::TensorRef;
|
||||
|
||||
/// Tensor reference to sync tensor
|
||||
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
|
||||
|
||||
/// Const tensor reference to source tensor
|
||||
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
|
||||
|
||||
/// Array type used to output
|
||||
using OutputAccessType = Array<
|
||||
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Array type used by output functor
|
||||
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Array type used by output functor
|
||||
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Tensor access type
|
||||
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount = typename Base::WarpCount;
|
||||
|
||||
/// Shared memory allocation from epilogue base class
|
||||
using BaseSharedStorage = typename Base::SharedStorage;
|
||||
|
||||
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
|
||||
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
|
||||
|
||||
/// Used for the broadcast
|
||||
struct BroadcastDetail {
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar column indices handled by each thread
|
||||
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
|
||||
|
||||
/// Number of distinct scalar row indices handled by each thread
|
||||
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Number of threads per threadblock
|
||||
static int const kThreadCount = kWarpSize * WarpCount::kCount;
|
||||
|
||||
/// Number of distinct threads per row of output tile
|
||||
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
|
||||
|
||||
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
|
||||
static int const kThreadRows = kThreadCount / kThreadsPerRow;
|
||||
|
||||
/// I'm not sure what I meant here.
|
||||
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
|
||||
|
||||
/// Shape of the shared memory allocation for the epilogue
|
||||
using StorageShape = MatrixShape<
|
||||
kThreadRows,
|
||||
Shape::kN
|
||||
>;
|
||||
|
||||
/// Debug printing
|
||||
CUTLASS_DEVICE
|
||||
static void print() {
|
||||
printf("BroadcastDetail {\n");
|
||||
printf(
|
||||
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
|
||||
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
|
||||
kColumnsPerThread,
|
||||
kRowsPerThread,
|
||||
kThreadCount,
|
||||
kThreadsPerRow,
|
||||
kThreadRows,
|
||||
kThreadAccessesPerRow,
|
||||
StorageShape::kRow,
|
||||
StorageShape::kColumn,
|
||||
StorageShape::kCount
|
||||
);
|
||||
printf("};\n");
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
|
||||
struct SharedStorage {
|
||||
union {
|
||||
BaseSharedStorage base;
|
||||
};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
SharedStorage() { }
|
||||
};
|
||||
|
||||
public:

  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
    "Mismatch between shared load iterator and output tile iterator.");

  static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
    "OutputTileIterator::Fragment::kElements must be divisible by OutputTileIterator::kElementsPerAccess.");
private:

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Thread index within the threadblock
  int thread_idx_;

public:
  /// Constructor
  CUTLASS_DEVICE
  EpilogueWithBroadcast(
    SharedStorage &shared_storage,                ///< Shared storage object
    int thread_idx,                               ///< ID of a thread within the threadblock
    int warp_idx,                                 ///< ID of warp within threadblock
    int lane_idx                                  ///< ID of thread within warp
  ):
    Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
    shared_load_iterator_(shared_storage.base.reference(), thread_idx),
    thread_idx_(thread_idx)
  { }
  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                    ///< Output operator
    ElementVector const * broadcast_ptr,          ///< Broadcast vector
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator,           ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator,           ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =             ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =       ///< Threadblock's initial offset within the problem size space
        MatrixCoord()) {

    BroadcastFragment broadcast_fragment;

    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
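    // The source tensor is read only when the output operator actually uses it
    // (e.g. when beta is non-zero); otherwise the cheaper source-free path is taken.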
    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        tensor_iterator);
    }
    else {
      compute_source_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        source_iterator,
        tensor_iterator);
    }
  }
private:

  CUTLASS_DEVICE
  void load_broadcast_fragment_(
    BroadcastFragment & broadcast_fragment,       ///< Fragment containing the accumulated partial reduction over columns
    ElementVector const * broadcast_ptr,          ///< Broadcast vector
    MatrixCoord const &problem_size,              ///< Problem size needed to guard against out-of-bounds accesses
    MatrixCoord const &threadblock_offset         ///< Threadblock's initial offset within the problem size space
  ) {

    broadcast_fragment.clear();

    // If no pointer is supplied, the fragment remains all zeros and no memory accesses are performed
    if (!broadcast_ptr) {
      return;
    }

    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;

    ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
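    // Each thread walks its column slices of the broadcast vector at a stride of
    // ThreadMap::Delta::kColumn, guarding against the problem extent; out-of-range
    // accesses keep the zeros established by loaded.clear().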
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      ComputeFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }
  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                                      WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                              (1 - Base::kFragmentsPerIteration));
      }
    }
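    // Dispatches to the helper<Advance>() whose compile-time advance matches the runtime
    // `pos`. The braced-init-list expands the Seq pack: each element evaluates
    // (pos == Seq * kFragmentsPerIteration) and invokes helper only on a match, via the
    // short-circuited comma expression. This avoids a switch over iteration indices.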
    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };
  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //
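    // Unrolling the epilogue's main loop is optional: IterationsUnroll trades larger code
    // size for performance, which matters when the elementwise operation is complicated.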
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

      //
      // Convert and store fragment
      //
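      // Two barriers bracket the shared-memory staging: the first ensures every warp has
      // finished reading the previous tile before it is overwritten; the second ensures
      // the new tile is fully written before any thread loads it back.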
      __syncthreads();

      acc2smem_source_not_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                       Base::kFragmentsPerIteration>>::push(iter,
                                                                            accum_fragment_iterator,
                                                                            this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {

        typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
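        // Split-k case: each k-partition wrote partial accumulators to its own
        // shared-memory tile. Sum them into fragment 0, then rewind the pointer
        // for the next iteration.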
        else if (kPartitionsK > 1) {

          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Apply output operation
        //

        typename OutputTileIterator::Fragment frag_Z;
        typename TensorTileIterator::Fragment frag_T;

        apply_output_operator_source_not_needed_(
          frag_Z,
          frag_T,
          output_op,
          aligned_accum_fragment[0],
          broadcast_fragment);

        //
        // Conditionally store fragments
        //

        if (OutputOp::kStoreZ) {
          destination_iterator.store(frag_Z);
          ++destination_iterator;
        }

        if (OutputOp::kStoreT) {
          tensor_iterator.store(frag_T);
          ++tensor_iterator;
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }
  template <class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE
    static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                       WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }
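    // Same pack-expansion dispatch as acc2smem_source_not_needed: selects the
    // helper<Seq>() whose compile-time advance equals the runtime position `pos`.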
    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };
  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator,           ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    typename OutputTileIterator::Fragment source_fragment;
    source_fragment.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator.load(source_fragment);
      ++source_iterator;

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
          iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is greater than one, reduce amongst the k-slices
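      // Each k-partition occupies StorageShape::kRow / PartitionsK rows of the
      // shared-memory allocation; step through the partitions, accumulate into
      // fragment 0, then step back to the first partition's tile.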
      if (kPartitionsK > 1) {
        plus<typename SharedLoadIterator::Fragment> add_fragments;
        const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_tile_offset({tile_row_offset, 0});
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK - 1) * tile_row_offset, 0});
      }
      //
      // Apply output operation
      //

      typename OutputTileIterator::Fragment frag_Z;
      typename TensorTileIterator::Fragment frag_T;

      apply_output_operator_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        source_fragment,
        broadcast_fragment);

      //
      // Conditionally store fragments
      //

      if (OutputOp::kStoreZ) {
        destination_iterator.store(frag_Z);
        ++destination_iterator;
      }

      if (OutputOp::kStoreT) {
        tensor_iterator.store(frag_T);
        ++tensor_iterator;
      }
    }
  }
  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    typename OutputTileIterator::Fragment const &frag_C,
    BroadcastFragment const &frag_Broadcast) {

    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    OutputAccessType const *frag_C_ptr =
        reinterpret_cast<OutputAccessType const *>(&frag_C);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
        reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
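    // The broadcast vector varies only along columns, so the fragment's accesses reuse
    // entry (i % ThreadMap::Iterations::kColumn): every row of the tile sees the same
    // broadcast values for a given column.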
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_C_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }
  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    BroadcastFragment const &frag_Broadcast) {

    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
        reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }
};
////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////