Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5fe09c2d67 | |||
| 6b69c79ac3 | |||
| 62e438f450 | |||
| 808c25337a | |||
| 6fc5008803 | |||
| a3bcc6981d | |||
| 3b28642801 | |||
| 538592dea4 | |||
| 2e07c4cc2f | |||
| 9ac255863f | |||
| 59e2aa505a | |||
| 4e8af93da1 | |||
| 6c2f8f2fb8 | |||
| 598e35401c | |||
| f4b0a33633 | |||
| bb35a3ba6f | |||
| 7ec3a87f22 |
45
CHANGELOG.md
45
CHANGELOG.md
@ -1,6 +1,46 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
# CUTLASS 2.x
|
||||
## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)
|
||||
|
||||
* **TF32x3:** emulated single-precision using Tensor Cores
|
||||
* 45+ TFLOPs on NVIDIA A100
|
||||
* [GEMM SDK example](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
|
||||
* [Conv Fprop SDK example](/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
|
||||
* [SDK example](/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](/include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates from the community (thanks!)
|
||||
|
||||
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
* Maxwell and Pascal GPU architectures
|
||||
* Ubuntu 16.04
|
||||
* CUDA 10.2
|
||||
|
||||
|
||||
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Use these when accumulation and epilogue compute types are all `cutlass::half_t`
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/)
|
||||
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
|
||||
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
|
||||
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03)
|
||||
* Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators)
|
||||
* Tuning for GEMMs fused with partial reductions
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
|
||||
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
@ -23,7 +63,8 @@
|
||||
* Many improvements to the epilogue.
|
||||
* Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Performance improvement for FP16 tensor core kernels
|
||||
* Bug fixes
|
||||
* Bug fixes
|
||||
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
|
||||
* Updated minimum CUDA Toolkit requirement to 10.2
|
||||
* [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
|
||||
@ -32,7 +32,7 @@ endif()
|
||||
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
|
||||
project(CUTLASS VERSION 2.6.0 LANGUAGES CXX)
|
||||
project(CUTLASS VERSION 2.7.0 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 10.2)
|
||||
@ -168,6 +168,11 @@ if (${CUTLASS_NVCC_VERBOSE})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS NAMESPACE
|
||||
#
|
||||
set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS")
|
||||
|
||||
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
@ -183,10 +188,18 @@ set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kern
|
||||
|
||||
# Test Levels L0, L1, L2
|
||||
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
|
||||
|
||||
set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
|
||||
|
||||
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
|
||||
|
||||
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
|
||||
#
|
||||
@ -239,7 +252,7 @@ if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
|
||||
endif()
|
||||
|
||||
@ -383,6 +396,8 @@ function(cutlass_apply_standard_compile_options TARGET)
|
||||
set(_FLAGS_DEBUG ${__CUTLASS_CUDA_FLAGS_DEBUG} ${__CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE CUTLASS)
|
||||
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
@ -425,6 +440,7 @@ set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/includ
|
||||
include_directories(${CUTLASS_INCLUDE_DIR})
|
||||
|
||||
target_compile_features(CUTLASS INTERFACE cxx_std_11)
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE})
|
||||
|
||||
if (NOT DEFINED CUTLASS_REVISION)
|
||||
|
||||
@ -564,10 +580,12 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
|
||||
# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
|
||||
# options given. If this option is not used, a single test with no arguments is generated.
|
||||
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
|
||||
# test results to speed up test runtime.
|
||||
#
|
||||
|
||||
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
|
||||
set(oneValueArgs DISABLE_TESTS)
|
||||
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE)
|
||||
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
@ -575,6 +593,17 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
set(__DISABLE_TESTS OFF)
|
||||
endif()
|
||||
|
||||
if (__RESULT_CACHE_FILE)
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND}
|
||||
ARGS -E copy ${__RESULT_CACHE_FILE} "$<TARGET_FILE_DIR:${TARGET}>"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS)
|
||||
|
||||
# file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
@ -583,6 +612,15 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
TARGETS ${TARGET}
|
||||
RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
|
||||
)
|
||||
|
||||
if (__RESULT_CACHE_FILE)
|
||||
|
||||
install(
|
||||
FILES ${__RESULT_CACHE_FILE}
|
||||
DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}/
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
@ -24,6 +24,9 @@ Markus Hohnerbach
|
||||
Aditya Atluri
|
||||
David Tanner
|
||||
Manikandan Ananth
|
||||
|
||||
## CUTLASS Product Manager
|
||||
Matthew Nicely
|
||||
|
||||
## CONTRIBUTORS
|
||||
Timothy Costa
|
||||
@ -56,7 +59,6 @@ Olivier Giroux
|
||||
Stephen Jones
|
||||
Rishkul Kulkarni
|
||||
Bryce Lelbach
|
||||
Matthew Nicely
|
||||
Joel McCormack
|
||||
Kyrylo Perelygin
|
||||
|
||||
|
||||
25
CUDA.cmake
25
CUDA.cmake
@ -74,7 +74,7 @@ find_library(
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
@ -89,10 +89,10 @@ if(NOT TARGET cudart AND CUDART_LIBRARY)
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cudart SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cudart ALIAS cudart)
|
||||
|
||||
|
||||
set_property(
|
||||
TARGET cudart
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
@ -120,7 +120,7 @@ find_library(
|
||||
lib64/stubs
|
||||
lib/stubs
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
@ -135,10 +135,10 @@ if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cuda_driver SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cuda_driver ALIAS cuda_driver)
|
||||
|
||||
|
||||
set_property(
|
||||
TARGET cuda_driver
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
@ -164,7 +164,7 @@ find_library(
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
@ -179,10 +179,10 @@ if(NOT TARGET nvrtc AND NVRTC_LIBRARY)
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(nvrtc SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
add_library(nvidia::nvrtc ALIAS nvrtc)
|
||||
|
||||
|
||||
set_property(
|
||||
TARGET nvrtc
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
@ -242,7 +242,7 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
|
||||
set(CUDA_FILE_ARGS)
|
||||
set(TARGET_SOURCE_ARGS)
|
||||
|
||||
|
||||
foreach(ARG ${__UNPARSED_ARGUMENTS})
|
||||
if(${ARG} MATCHES ".*\.cu$")
|
||||
list(APPEND CUDA_FILE_ARGS ${ARG})
|
||||
@ -250,7 +250,7 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
list(APPEND TARGET_SOURCE_ARGS ${ARG})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
|
||||
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
|
||||
while(NUM_CUDA_FILE_ARGS GREATER 0)
|
||||
list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
|
||||
@ -280,7 +280,6 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_add_library NAME)
|
||||
|
||||
set(options)
|
||||
|
||||
111
README.md
111
README.md
@ -1,15 +1,15 @@
|
||||

|
||||
|
||||
# CUTLASS 2.6
|
||||
# CUTLASS 2.8
|
||||
|
||||
_CUTLASS 2.6 - July 2021_
|
||||
_CUTLASS 2.8 - November 2021_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
reusable, modular software components abstracted by C++ template classes. These
|
||||
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
|
||||
high-performance matrix-multiplication (GEMM) and related computations at all levels
|
||||
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
|
||||
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
|
||||
these "moving parts" into reusable, modular software components abstracted by C++ template
|
||||
classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
|
||||
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
|
||||
resulting flexibility simplifies their use as building blocks within custom kernels
|
||||
and applications.
|
||||
@ -20,82 +20,38 @@ multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32), double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
|
||||
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, and Ampere architectures.
|
||||
|
||||
Additionaly, CUTLASS implements high-performance convolution (implicit GEMM).
|
||||
Implicit GEMM is the formulation of a convolution operation as a GEMM. This allows CUTLASS
|
||||
to build convolutions by reusing highly optimized warp-wide GEMM components and below.
|
||||
CUTLASS implements high-performance Convolution via the implicit GEMM algorithm.
|
||||
Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of
|
||||
CUTLASS's modular GEMM pipeline.
|
||||
This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below.
|
||||
|
||||
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
|
||||
See the [functionality listing](/media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
# What's New in CUTLASS 2.6
|
||||
CUTLASS 2.6 is a minor update to CUTLASS adding:
|
||||
- Fused [broadcast](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reductions](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution
|
||||
- [Quaternion-valued GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu) in single-precision
|
||||
- [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation offers up to 4x performance improvements over previous strided Dgrad
|
||||
- 64-bit strides for large tensor allocations
|
||||
- [General affine layouts](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) fp64 tensor core and simt GEMM
|
||||
- Enhanced functionality, boosted performance, and bug fixes in the epilogue.
|
||||
- Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
- Adopt new L2 prefetch feature in [ptx instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-isa-version-7-4).
|
||||
- Numerous updates from the community (thanks!)
|
||||
- See the [CHANGELOG](CHANGELOG.md) for more details
|
||||
# What's New in CUTLASS 2.8
|
||||
CUTLASS 2.8 is an update to CUTLASS adding:
|
||||
- [TF32x3:](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm) emulated single-precision using Tensor Cores; 45+ TFLOPs on NVIDIA A100
|
||||
- [Mainloop fusion for Convolution:](/examples/25_ampere_fprop_mainloop_fusion) convolution with fused per-channel bias-add
|
||||
- [Grouped GEMM:](/examples/24_gemm_grouped) similar to batched GEMM with distinct problem size per group
|
||||
- [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing.
|
||||
- Optimal performance using [CUDA 11.5](https://developer.nvidia.com/cuda-downloads)
|
||||
- **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
- Maxwell and Pascal GPU architectures
|
||||
- Ubuntu 16.04
|
||||
- CUDA 10.2
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
|
||||
# What's New in CUTLASS 2.5
|
||||
CUTLASS 2.5 is a minor update to CUTLASS adding:
|
||||
- [Tensor reductions](/test/unit/reduction/device/tensor_reduce_contiguous.cu)
|
||||
- [Optimizations for 3-D convolution](include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h)
|
||||
- [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md)
|
||||
|
||||
# What's New in CUTLASS 2.4
|
||||
CUTLASS 2.4 is a significant update to CUTLASS adding:
|
||||
- 1-D, 2-D, and 3-D convolution targeting Tensor and CUDA cores for NVIDIA Ampere, Turing, and Volta GPU architectures
|
||||
- CUTLASS profiler support for convolution
|
||||
- [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
|
||||
# What's New in CUTLASS 2.3
|
||||
|
||||
CUTLASS 2.3 is a minor update to CUTLASS adding:
|
||||
- GEMMs targeting structured [Sparse Tensor Cores](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) in NVIDIA Ampere Architecture GPUs
|
||||
- Fast SGEMM kernels targeting GeForce RTX 30-series CUDA Cores
|
||||
- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later
|
||||
|
||||
# What's New in CUTLASS 2.2
|
||||
|
||||
CUTLASS 2.2 is a significant update to CUTLASS adding:
|
||||
|
||||
- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
|
||||
- Deep software pipelines using asynchronous copy
|
||||
- Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745)
|
||||
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later
|
||||
|
||||
# What's New in CUTLASS 2.1
|
||||
|
||||
CUTLASS 2.1 is a minor update to CUTLASS adding:
|
||||
|
||||
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
|
||||
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
|
||||
# What's New in CUTLASS 2.0
|
||||
|
||||
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
|
||||
|
||||
- Better performance over 1.x, particularly for kernels targeting Turing Tensor Cores
|
||||
- Robust and durable templates that reliably span the design space
|
||||
- Encapsulated functionality that may be reusable in other contexts
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for more details.**
|
||||
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
|
||||
<p align="center"><img src=/media/images/cutlass-2.8-gemm-performance.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
@ -107,8 +63,8 @@ using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires a C++11 host compiler and
|
||||
performs best when built with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 10.2, CUDA 11.0, CUDA 11.1, CUDA 11.2, and CUDA 11.3.
|
||||
performs best when built with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, and CUDA 11.4.
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
@ -116,29 +72,26 @@ We have tested the following environments.
|
||||
|-----------------|----------|
|
||||
| Windows 10 | Microsoft Visual Studio 2015|
|
||||
| | Microsoft Visual Studio 2017|
|
||||
| Ubuntu 16.04 | GCC 5.4.0 |
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
| Ubuntu 20.04 | GCC 10.2.0 |
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
|
||||
Additionally, CUTLASS may be built with clang.
|
||||
See [these instructions](media/docs/quickstart.md#clang) for more details.
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
|
||||
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|
||||
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|
||||
|
||||
For all GPUs, we recommend compiling with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
For all GPUs, we recommend compiling with the [**CUDA 11.5 Toolkit**](https://developer.nvidia.com/cuda-toolkit)
|
||||
for best performance.
|
||||
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|
||||
|---|---|---|---|
|
||||
|NVIDIA Tesla P100|6.0|9.2| |
|
||||
|NVIDIA GeForce 1080|6.1|9.2| |
|
||||
|NVIDIA TitanXP|6.1|9.2| |
|
||||
|NVIDIA Tesla V100|7.0|9.2|10.1|
|
||||
|NVIDIA TitanV|7.0|9.2|10.1|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|
||||
|NVIDIA Tesla T4|7.5|10.0|10.2|
|
||||
|NVIDIA A100|8.0|11.0|11.0|
|
||||
|NVIDIA A10 |8.6|11.1|11.1|
|
||||
|NVIDIA GeForce 3090|8.6|11.1|11.1|
|
||||
|
||||
# Documentation
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
set(TEST_COMMAND_00 RowMajor --extent=16,16)
|
||||
set(TEST_COMMAND_01 "ColumnMajorInterleaved<4>" --extent=32,8 --output-shape=16 --vectorize=4)
|
||||
set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
03_visualize_layout
|
||||
|
||||
@ -48,169 +48,13 @@ cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 1, 1, 64}, // filter size (KRSC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{128, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Analytic Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Analytic Convolution Fprops...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
@ -220,9 +64,9 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute beta0 = ElementCompute(1); //use beta for bias
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
@ -245,7 +89,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
@ -269,7 +113,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
@ -302,14 +147,14 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -326,10 +171,12 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75() {
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
@ -348,14 +195,92 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75() {
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 256, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
|
||||
@ -55,159 +55,6 @@ cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_f16_sm80() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Analytic Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_f16_sm80() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Analytic Convolution Fprops...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
|
||||
|
||||
@ -48,169 +48,13 @@ cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
|
||||
{128, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 1, 1, 64}, // filter size (KRSC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
{128, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Analytic Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Analytic Convolution Fprops...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
@ -222,7 +66,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
@ -304,12 +148,12 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75() {
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
@ -330,6 +174,7 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75() {
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
@ -348,14 +193,91 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75() {
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 256, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
|
||||
@ -55,162 +55,6 @@ cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
|
||||
{128, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_s8_sm80() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Analytic Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_s8_sm80() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Analytic Convolution Fprops...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
|
||||
@ -79,16 +79,19 @@ public:
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
|
||||
|
||||
@ -99,9 +102,10 @@ public:
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
@ -138,37 +142,50 @@ public:
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) {
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
|
||||
tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_computed.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
@ -196,7 +213,7 @@ public:
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
{tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
|
||||
tensor_D0_computed.device_ref(),
|
||||
{alpha0, beta0},
|
||||
split_k_mode
|
||||
@ -205,7 +222,7 @@ public:
|
||||
problem_size_1,
|
||||
tensor_D0_computed.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
|
||||
tensor_D1_computed.device_ref(),
|
||||
{alpha1, beta1},
|
||||
split_k_mode
|
||||
@ -279,7 +296,7 @@ public:
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
{tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0);
|
||||
@ -302,7 +319,7 @@ public:
|
||||
problem_size_1,
|
||||
tensor_D0_reference.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
|
||||
tensor_D1_reference.device_ref(),
|
||||
alpha1,
|
||||
beta1);
|
||||
@ -344,10 +361,12 @@ public:
|
||||
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
|
||||
<< "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
||||
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
||||
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
||||
|
||||
@ -375,15 +394,20 @@ public:
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
|
||||
|
||||
@ -394,9 +418,12 @@ public:
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
@ -433,35 +460,56 @@ public:
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) {
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
ElementCompute alpha0,
|
||||
ElementCompute alpha1,
|
||||
uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.K});
|
||||
tensor_Bias0.resize({1, problem_size_0.K});
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
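// Editorial note (not part of the original change): a zero alpha0 is this testbed's signal for
// per-channel scaling, so tensor_Scale0 is resized, filled, and synced only in that case, while
// tensor_Bias0 is always a K-length vector. A caller of run() would request the per-channel
// path like this (runner and problem-size names reused from other tests in this diff):
//
//   ElementCompute alpha0 = ElementCompute(0);   // 0 => conv0 scaled per channel via tensor_Scale0
//   fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1,
//                   cutlass::conv::SplitKMode::kSerial, alpha0, beta0, alpha1, beta1);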
|
||||
@ -479,7 +527,7 @@ public:
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
initialize(problem_size_0, problem_size_1);
|
||||
initialize(problem_size_0, problem_size_1, alpha0, alpha1);
|
||||
|
||||
// configure the operator
|
||||
B2bConv2d b2b_conv2d_op;
|
||||
@ -490,15 +538,31 @@ public:
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
|
||||
tensor_D1_computed.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
split_k_mode
|
||||
);
|
||||
|
||||
cutlass::Status status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
||||
cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
|
||||
<< " problem_size_0.K = problem_size_1.C\n"
|
||||
<< " problem_size_1.R = problem_size_1.S = 1\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.K\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
|
||||
}
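// Editorial note (not part of the original change): the printed requirements above mirror the
// checks added to B2bImplicitGemmConvolution::can_implement later in this diff:
//   equal N*P*Q                           -> both implicit GEMMs share the same M dimension;
//   problem_size_0.K == problem_size_1.C  -> conv0's GEMM N equals conv1's GEMM K;
//   R = S = 1 for conv1                   -> the second convolution is a 1x1 filter;
//   K0 <= ThreadblockShape0::kN and K1 <= ThreadblockShape1::kN -> each output-channel count
//   fits in a single threadblock tile so the intermediate can stay resident.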
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
@ -551,7 +615,10 @@ public:
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0);
|
||||
beta0,
|
||||
nullptr, // stream
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref());
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
||||
@ -571,7 +638,7 @@ public:
|
||||
problem_size_1,
|
||||
tensor_D0_reference.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
|
||||
tensor_D1_reference.device_ref(),
|
||||
alpha1,
|
||||
beta1);
|
||||
@ -612,8 +679,11 @@ public:
|
||||
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
||||
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
||||
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
||||
|
||||
|
||||
@ -80,18 +80,21 @@ public:
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
|
||||
|
||||
@ -102,9 +105,10 @@ public:
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
@ -141,6 +145,12 @@ public:
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
@ -153,17 +163,20 @@ public:
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
|
||||
tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
|
||||
@ -177,11 +190,13 @@ public:
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_computed.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
@ -392,17 +407,22 @@ public:
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0_reordered;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
|
||||
|
||||
@ -413,9 +433,12 @@ public:
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
@ -452,30 +475,47 @@ public:
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) {
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
ElementCompute alpha0,
|
||||
ElementCompute alpha1,
|
||||
uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.K});
|
||||
tensor_Bias0.resize({1, problem_size_0.K});
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
||||
|
||||
//Reorder B0 and B1
|
||||
cutlass::reorder_convK<16, InterleavedK>(
|
||||
@ -487,10 +527,14 @@ public:
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
@ -508,7 +552,7 @@ public:
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
initialize(problem_size_0, problem_size_1);
|
||||
initialize(problem_size_0, problem_size_1, alpha0, alpha1);
|
||||
|
||||
// configure the operator
|
||||
B2bConv2d b2b_conv2d_op;
|
||||
@ -519,6 +563,8 @@ public:
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0_reordered.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref(),
|
||||
tensor_B1_reordered.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1_computed.device_ref(),
|
||||
@ -527,7 +573,21 @@ public:
|
||||
split_k_mode
|
||||
);
|
||||
|
||||
cutlass::Status status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
||||
cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
|
||||
<< " problem_size_0.K = problem_size_1.C\n"
|
||||
<< " problem_size_1.R = problem_size_1.S = 1\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.K\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
|
||||
}
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
@ -581,7 +641,10 @@ public:
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0);
|
||||
beta0,
|
||||
nullptr, // stream
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref());
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
||||
@ -644,6 +707,8 @@ public:
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
|
||||
@ -390,14 +390,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<B2bGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -55,6 +55,8 @@ public:
|
||||
using LayoutC = typename B2bImplicitGemmKernel::LayoutC;
|
||||
using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator;
|
||||
using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute;
|
||||
using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias;
|
||||
using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias;
|
||||
using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass;
|
||||
using ArchTag = typename B2bImplicitGemmKernel::ArchTag;
|
||||
using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0;
|
||||
@ -126,6 +128,26 @@ public:
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
// Determine if fusion sizes are valid

cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0);
cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1);

if(problem_size_0.m() != problem_size_1.m())
  return Status::kErrorInvalidProblem;

if(problem_size_0.n() != problem_size_1.k())
  return Status::kErrorInvalidProblem;

if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1)
  return Status::kErrorInvalidProblem;

if(problem_size_0.n() > ThreadblockShape0::kN)
  return Status::kErrorInvalidProblem;

if(problem_size_1.n() > ThreadblockShape1::kN)
  return Status::kErrorInvalidProblem;

return Status::kSuccess;
}
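// Editorial sketch (not part of the original change): callers are expected to run the new
// check before initialization, as the updated testbeds in this diff do:
//
//   cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);  // validate fusion shapes first
//   if (status == cutlass::Status::kSuccess) {
//     status = b2b_conv2d_op.initialize(b2b_conv2d_args);                   // then build params and workspace
//   }
//   CUTLASS_CHECK(status);                                                  // surfaces either failure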
|
||||
|
||||
@ -197,14 +219,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
cutlass::Kernel<B2bImplicitGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
@ -217,6 +231,8 @@ public:
|
||||
params_.ptr_A0 = args.ref_A0.data();
|
||||
params_.ptr_B0 = args.ref_B0.data();
|
||||
params_.ptr_C0 = args.ref_C0.data();
|
||||
params_.ptr_Scale0 = args.ref_Scale0.data();
|
||||
params_.ptr_Bias0 = args.ref_Bias0.data();
|
||||
params_.ptr_B1 = args.ref_B1.data();
|
||||
params_.ptr_C1 = args.ref_C1.data();
|
||||
params_.ptr_D1 = args.ref_D1.data();
|
||||
|
||||
@ -60,8 +60,10 @@ int run_sm75() {
|
||||
std::cout << "Running on SM75" << std::endl;
|
||||
pass &= run_nonfused_conv2d_fprop_optimized_f16_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_f16_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_f16_sm75_rf_res();
|
||||
pass &= run_nonfused_conv2d_fprop_optimized_s8_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_s8_sm75();
|
||||
pass &= run_fused_conv2d_fprop_optimized_s8_sm75_rf_res();
|
||||
|
||||
if(pass)
|
||||
return 1;
|
||||
|
||||
@ -79,6 +79,10 @@ struct B2bImplicitGemmConvolution {
|
||||
using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator;
|
||||
using ElementCompute = typename EpilogueOutputOp0::ElementCompute;
|
||||
|
||||
/// Scale and Bias
|
||||
using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||
using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout;
|
||||
|
||||
using WarpMmaOperator0 = typename B2bMma::Policy0::Operator;
|
||||
using WarpMmaOperator1 = typename B2bMma::Policy1::Operator;
|
||||
|
||||
@ -103,13 +107,14 @@ struct B2bImplicitGemmConvolution {
|
||||
|
||||
using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef;
|
||||
using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef;
|
||||
using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef;
|
||||
using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef;
|
||||
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
|
||||
|
||||
/// Check iterator A and B convolution dimensions are the same and
|
||||
// set device::B2bImplicitGemmConvolution::kConvDim
|
||||
static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim,
|
||||
"Convolution on different different dimensions is not supported");
|
||||
"Convolution on different dimensions is not supported");
|
||||
static int const kConvDim = B2bMma::IteratorA0::kConvDim;
|
||||
|
||||
/// Conv dimension and problem size structure (Conv2d or Conv3d)
|
||||
@ -148,6 +153,8 @@ struct B2bImplicitGemmConvolution {
|
||||
TensorRefA0 ref_A0;
|
||||
TensorRefB0 ref_B0;
|
||||
TensorRefC ref_C0;
|
||||
TensorRefScaleBias0 ref_Scale0;
|
||||
TensorRefScaleBias0 ref_Bias0;
|
||||
TensorRefB1 ref_B1;
|
||||
TensorRefC ref_C1;
|
||||
TensorRefC ref_D1;
|
||||
@ -178,6 +185,8 @@ struct B2bImplicitGemmConvolution {
|
||||
TensorRefA0 const & ref_A0,
|
||||
TensorRefB0 const & ref_B0,
|
||||
TensorRefC const & ref_C0,
|
||||
TensorRefScaleBias0 const & ref_Scale0,
|
||||
TensorRefScaleBias0 const & ref_Bias0,
|
||||
TensorRefB1 const & ref_B1,
|
||||
TensorRefC const & ref_C1,
|
||||
TensorRefC const & ref_D1,
|
||||
@ -190,6 +199,8 @@ struct B2bImplicitGemmConvolution {
|
||||
ref_A0(ref_A0),
|
||||
ref_B0(ref_B0),
|
||||
ref_C0(ref_C0),
|
||||
ref_Scale0(ref_Scale0),
|
||||
ref_Bias0(ref_Bias0),
|
||||
ref_B1(ref_B1),
|
||||
ref_C1(ref_C1),
|
||||
ref_D1(ref_D1),
|
||||
@ -218,6 +229,8 @@ struct B2bImplicitGemmConvolution {
|
||||
typename B2bMma::IteratorB0::Element const *ptr_B0;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C0;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0;
|
||||
typename B2bMma::IteratorB1::Params iterator_B1;
|
||||
typename B2bMma::IteratorB1::Element const *ptr_B1;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C1;
|
||||
@ -252,6 +265,8 @@ struct B2bImplicitGemmConvolution {
|
||||
ptr_B0(args.ref_B0.data()),
|
||||
iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)),
|
||||
ptr_C0(args.ref_C0.data()),
|
||||
ptr_Scale0(args.ref_Scale0.data()),
|
||||
ptr_Bias0(args.ref_Bias0.data()),
|
||||
iterator_B1(args.problem_size_1, args.ref_B1.layout()),
|
||||
ptr_B1(args.ref_B1.data()),
|
||||
iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)),
|
||||
@ -350,6 +365,28 @@ struct B2bImplicitGemmConvolution {
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||
params.ptr_Scale0,
|
||||
{1, params.problem_size_0.K},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
MatrixCoord(
|
||||
0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
||||
)
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||
params.ptr_Bias0,
|
||||
{1, params.problem_size_0.K},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
MatrixCoord(
|
||||
0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
@ -366,7 +403,8 @@ struct B2bImplicitGemmConvolution {
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0);
|
||||
b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
|
||||
File diff suppressed because it is too large
@ -77,6 +77,11 @@ template <
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// WarpIterator to load Scale or Bias vector from threadblock fragment
|
||||
typename FragmentIteratorA1ScaleBias_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -117,6 +122,10 @@ public:
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< WarpIterator to load Scale or Bias vector from threadblock fragment
|
||||
using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
@ -126,6 +135,9 @@ public:
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static const bool PerChannelScale = (OutputOp::kScale ==
|
||||
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
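// Editorial note (not part of the original change): PerChannelScale is driven entirely by the
// first epilogue's ScaleType. An OutputOp such as
//   cutlass::epilogue::thread::LinearCombinationRelu<
//       ElementC, Count, ElementAccumulator, ElementCompute,
//       cutlass::epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling>
// (ElementC, Count, and friends are placeholders here) makes kScale match, so the mainloop also
// streams a per-channel scale vector; with other scale kinds only the bias vector is loaded.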
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
@ -143,6 +155,9 @@ public:
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of Scale and Bias loaded from global memory
|
||||
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
@ -193,6 +208,8 @@ public:
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accumulator tile
|
||||
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
using WarpLoadedFragmentA1ScaleBias =
|
||||
typename FragmentIteratorA1ScaleBias::Fragment;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
@ -229,9 +246,9 @@ public:
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx)
|
||||
smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@ -343,11 +360,15 @@ public:
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A operand in global memory
|
||||
///< iterator over A0 operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B operand in global memory
|
||||
///< iterator over B0 operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over B operand in global memory
|
||||
///< iterator over A1 operand scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_scale,
|
||||
///< iterator over A1 operand bias vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_bias,
|
||||
///< iterator over B1 operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
@ -571,6 +592,20 @@ public:
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
|
||||
FragmentA1ScaleBias tb_frag_A1_scale;
|
||||
FragmentA1ScaleBias tb_frag_A1_bias;
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
|
||||
|
||||
if(PerChannelScale) {
|
||||
tb_frag_A1_scale.clear();
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
++iterator_A1_scale;
|
||||
}
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
@ -619,18 +654,29 @@ public:
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2];
|
||||
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
if(PerChannelScale) {
|
||||
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
}
|
||||
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]);
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0],
|
||||
warp_loaded_frag_A1_scale[0],
|
||||
warp_loaded_frag_A1_bias[0],
|
||||
output_op_0);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
@ -660,17 +706,40 @@ public:
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load threadblock-level scale/bias vector from global memory
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
|
||||
if(PerChannelScale) {
|
||||
tb_frag_A1_scale.clear();
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
++iterator_A1_scale;
|
||||
}
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
|
||||
}
|
||||
|
||||
// Load warp-level scale bias fragment from threadblock scale/bias vector
|
||||
if(PerChannelScale) {
|
||||
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
}
|
||||
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]);
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2],
|
||||
output_op_0);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
|
||||
@ -70,6 +70,12 @@ template <
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// FragmentIterator to load Scale or Bias vector from threadblock fragment
|
||||
typename FragmentIteratorA1ScaleBias_,
|
||||
// (concept: VectorFragmentIterator)
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
@ -92,13 +98,13 @@ template <
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to A operand
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to A operand
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
@ -106,7 +112,8 @@ template <
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bImplicitGemmPipelined : public gemm::threadblock::B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
class B2bImplicitGemmPipelined :
|
||||
public gemm::threadblock::B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
@ -121,17 +128,24 @@ public:
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A operand in global memory
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A1 operand from accumulator tile
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using FragmentIteratorA1ScaleBias =
|
||||
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
static const bool PerChannelScale = (OutputOp::kScale ==
|
||||
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
@ -152,6 +166,9 @@ public:
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of Scale and Bias loaded from global memory
|
||||
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
@ -182,6 +199,9 @@ private:
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accumulator tile
|
||||
using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
/// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment
|
||||
using WarpFragmentA1ScaleBias =
|
||||
typename FragmentIteratorA1ScaleBias::Fragment;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
@ -206,9 +226,9 @@ public:
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {
|
||||
smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@ -240,9 +260,11 @@ public:
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
@ -370,18 +392,33 @@ public:
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentA1ScaleBias tb_frag_A1_scale;
|
||||
FragmentA1ScaleBias tb_frag_A1_bias;
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
if(PerChannelScale)
|
||||
tb_frag_A1_scale.clear();
|
||||
tb_frag_A1_bias.clear();
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
if(PerChannelScale)
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
|
||||
if(PerChannelScale)
|
||||
++iterator_A1_scale;
|
||||
++iterator_A1_bias;
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
@ -391,15 +428,24 @@ public:
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1ScaleBias warp_frag_A1_scale[2];
|
||||
WarpFragmentA1ScaleBias warp_frag_A1_bias[2];
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
|
||||
if(PerChannelScale)
|
||||
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]);
|
||||
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]);
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0],
|
||||
warp_frag_A1_bias[0], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
if(PerChannelScale)
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
@ -447,13 +493,31 @@ public:
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
|
||||
if(PerChannelScale) {
|
||||
tb_frag_A1_scale.clear();
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
++iterator_A1_scale;
|
||||
}
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
if(PerChannelScale)
|
||||
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]);
|
||||
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]);
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_frag_A1_scale[(warp_mma_k + 1) % 2],
|
||||
warp_frag_A1_bias[(warp_mma_k + 1) % 2],
|
||||
output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if(PerChannelScale)
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
@ -0,0 +1,532 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
IteratorB1_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bImplicitGemmPipelinedSmemAccumulator :
|
||||
public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2>;
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using Policy0 = Policy0_; ///< Policy0 describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA0 = typename IteratorA0::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB0 = typename IteratorB0::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A0 operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B0 operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Statically assert kStages for MmaPipelined is two (double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
/// Epilogue in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bImplicitGemmPipelinedSmemAccumulator(
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
FragmentA0 tb_frag_A;
|
||||
FragmentB0 tb_frag_B0;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B0.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA0 warp_frag_A0[2];
|
||||
WarpFragmentB0 warp_frag_B0[2];
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/// 2nd Implicit Gemm
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
// int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1)
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -180,8 +180,8 @@ class B2bMmaBase {
|
||||
using SharedStorage0 = SharedStorage<Shape0, Policy0>;
|
||||
using SharedStorage1 = SharedStorage<Shape1, Policy1>;
|
||||
union B2bMmaSharedStorage {
|
||||
SharedStorage0 sharedStorage0;
|
||||
SharedStorage1 sharedStorage1;
|
||||
SharedStorage0 shared_storage0;
|
||||
SharedStorage1 shared_storage1;
|
||||
};
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ class B2bMmaBase {
|
||||
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
|
||||
typename Operator0::IteratorB warp_tile_iterator_B0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
|
||||
/// Iterator to load a warp-scoped tile of B1 operand from shared memory
|
||||
typename Operator1::IteratorB warp_tile_iterator_B1_;
|
||||
|
||||
public:
|
||||
@@ -214,9 +214,9 @@ public:
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
|
||||
warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
|
||||
warp_tile_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), lane_idx),
|
||||
warp_tile_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), lane_idx) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Shared Memory Accumulator Iterator
|
||||
typename SmemAccumulatorIterator0_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaBaseSmemAccumulator :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
|
||||
public:
|
||||
///< Base class
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
using Shape1 = Shape1_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
|
||||
using SmemAccumulatorIterator0 = SmemAccumulatorIterator0_;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
/// Shared storage object needed by accumulator
|
||||
template<
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename Padding_
|
||||
>
|
||||
class AccumulatorSharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using Padding = Padding_;
|
||||
|
||||
/// Tensor reference to the accumulator
|
||||
using TensorRefAccum = TensorRef<Element, Layout>;
|
||||
|
||||
/// Shape of the accumulator matrix in shared memory
|
||||
using ShapeAccum = MatrixShape<Shape::kM + Padding::kRow,
|
||||
Shape::kN + Padding::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for accumulator
|
||||
AlignedBuffer<Element, ShapeAccum::kCount> accum;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the Accum matrix
|
||||
CUTLASS_DEVICE
|
||||
static Layout LayoutAccum() {
|
||||
return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the Accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefAccum accum_ref() {
|
||||
return TensorRefAccum{accum.data(), LayoutAccum()};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
using AccumulatorSharedStorage0 = AccumulatorSharedStorage<
|
||||
Shape0, typename SmemAccumulatorIterator0::Element,
|
||||
typename SmemAccumulatorIterator0::TensorLayout,
|
||||
typename SmemAccumulatorIterator0::Padding>;
|
||||
|
||||
struct B2bMmaSharedStorage {
|
||||
typename Base::B2bMmaSharedStorage b2b_mma_shared_storage;
|
||||
AccumulatorSharedStorage0 accumulator_shared_storage0;
|
||||
};
|
||||
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaBaseSmemAccumulator(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage.b2b_mma_shared_storage, thread_idx, warp_idx, lane_idx) {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -247,9 +247,9 @@ public:
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx)
|
||||
smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@@ -395,10 +395,8 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
@@ -490,10 +488,8 @@ public:
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@@ -601,10 +597,8 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
@@ -634,9 +628,7 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
if (gemm_k_iterations_1 == 0) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
@@ -694,9 +686,7 @@ public:
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (gemm_k_iterations_1 == 0) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
@@ -793,9 +783,7 @@ public:
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
if (gemm_k_iterations_1 == 1) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
|
||||
@@ -207,9 +207,9 @@ public:
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {
|
||||
smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
|
||||
|
||||
@@ -166,7 +166,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
23_ampere_gemm_operand_reduction_fusion
|
||||
ampere_gemm_operand_reduction_fusion.cu
|
||||
)
|
||||
|
||||
@@ -0,0 +1,747 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
The example demonstrates how to reduce one of the operands of the GEMM along the k-dimension while
computing the GEMM, so the output also contains either an Mx1 or a 1xN vector. It only works with Ampere
HMMA 16x8x16 FP16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
core instructions.

Most of the reduction is done at the gemm/warp level; see gemm/warp/mma_with_reduction_tensor_op.h.
A small part of the reduction is done in the epilogue before storing the vector; see
epilogue/threadblock/epilogue_gemm_k_reduction.h.
*/
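// For illustration only: a minimal host-side sketch (a hypothetical helper, not part of this
// example's API) of the quantity the fused kernel below produces when ReduceKForA is true,
// namely the Mx1 row-sum of A over the K dimension, emitted alongside D = alpha * A * B + beta * C.
inline void reference_k_reduction_sketch(float const *A, float *reduction, int M, int K, int lda) {
  // Column-major A, as used in this example: element (m, k) lives at A[m + k * lda].
  for (int m = 0; m < M; ++m) {
    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
      sum += A[m + k * lda];
    }
    reduction[m] = sum;  // one reduced value per row of A
  }
}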
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
|
||||
#include "cutlass/reduction/device/reduce_split_k.h"
|
||||
#include "cutlass/reduction/kernel/reduce_split_k.h"
|
||||
#include "cutlass/reduction/thread/reduction_operators.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::ColumnMajor;
|
||||
using LayoutInputB = cutlass::layout::RowMajor;
|
||||
using LayoutOutput = cutlass::layout::ColumnMajor;
|
||||
// Layout of the output vector
|
||||
using LayoutGemmKReduction = cutlass::layout::PitchLinear;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 4;
|
||||
|
||||
// Reduce A or B operand along the K dimension
|
||||
constexpr bool ReduceKForA = true;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
|
||||
ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8,
|
||||
ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
ReduceKForA,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd
|
||||
>::GemmKernel;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Below is the reduction kernel used in the case of parallel split-k
|
||||
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;
|
||||
|
||||
using ReduceOp = cutlass::reduction::thread::ReduceAdd<
|
||||
ElementAccumulator,
|
||||
ElementOutput,
|
||||
EpilogueOp::kCount
|
||||
>;
|
||||
|
||||
using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
|
||||
ReduceGemmSplitKShape,
|
||||
EpilogueOp,
|
||||
ReduceOp
|
||||
>;
|
||||
|
||||
using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK<ReduceGemmSplitKKernel>;
|
||||
|
||||
using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue,
|
||||
cutlass::epilogue::thread::ScaleType::Nothing>;
|
||||
|
||||
using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
|
||||
ReduceVectorSplitKShape,
|
||||
DummyEpilogueOp,
|
||||
ReduceOp
|
||||
>;
|
||||
|
||||
using ReduceVectorSplitK = cutlass::reduction::device::ReduceSplitK<ReduceVectorSplitKKernel>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int split_k_slices;
|
||||
bool parallel_split_k;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size(1024, 1024, 1024),
|
||||
split_k_slices(1),
|
||||
parallel_split_k(false),
|
||||
reference_check(true),
|
||||
measure_performance(false),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(-1),
|
||||
beta(-1),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS GEMM implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
||||
//
|
||||
int const kAlignment = 8;
|
||||
|
||||
if ((problem_size.m() % kAlignment) ||
|
||||
(problem_size.n() % kAlignment) ||
|
||||
(problem_size.k() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates the problem size and split-k configuration
|
||||
void update(
|
||||
cutlass::gemm::GemmCoord problem_size,
|
||||
int split_k_slices,
|
||||
bool parallel_split_k) {
|
||||
|
||||
this->problem_size = problem_size;
|
||||
this->split_k_slices = split_k_slices;
|
||||
this->parallel_split_k = parallel_split_k;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("parallel-split-k")) {
|
||||
parallel_split_k = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
cmd.get_cmd_line_argument("split-k-slices", split_k_slices);
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "28_ampere_gemm_bias_fusion example\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M\n"
|
||||
<< " --n <int> GEMM N\n"
|
||||
<< " --k <int> GEMM K\n"
|
||||
<< " --split-k-slices <int> Split K Slices\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --parallel-split-k If set (true), use parallel split K\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several problem sizes.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/28_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "ID,M,N,K,SplitK-Slices,Parallel-SplitK,Runtime";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "gemm_" << idx << ","
|
||||
<< options.problem_size.m() << ","
|
||||
<< options.problem_size.n() << ","
|
||||
<< options.problem_size.k() << ","
|
||||
<< options.split_k_slices << ","
|
||||
<< options.parallel_split_k << ","
|
||||
<< runtime_ms ;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.problem_size.mk());
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.problem_size.kn());
|
||||
|
||||
|
||||
// Create matrix C with dimensions M x N used as the source accumulator
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.problem_size.mn());
|
||||
|
||||
// Create tensor D used to store output from CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.problem_size.mn());
|
||||
// Create matrix D with dimensions M x N used to store output from reference
|
||||
// kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.problem_size.mn());
|
||||
|
||||
int reduce_vector_length = ReduceKForA ? options.problem_size.m() : options.problem_size.n();
|
||||
|
||||
cutlass::HostTensor<ElementOutput, LayoutGemmKReduction> tensor_reduction({reduce_vector_length, 1});
|
||||
cutlass::HostTensor<ElementOutput, LayoutGemmKReduction> tensor_ref_reduction({reduce_vector_length, 1});
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(4),
|
||||
ElementInputA(-4),
|
||||
0); // <- Fill tensor A on host with uniform-distribution random data
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(4),
|
||||
ElementInputB(-4),
|
||||
0); // <- Fill tensor B on host with uniform-distribution random data
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(4),
|
||||
ElementOutput(-4),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_reduction.host_view()); // <- fill the reduction vector on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_reduction.host_view()); // <- fill the reference reduction vector on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
tensor_reduction.sync_device();
|
||||
|
||||
// Initialize alpha for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);
|
||||
|
||||
cutlass::gemm::GemmUniversalMode mode = options.parallel_split_k ?
|
||||
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel :
|
||||
cutlass::gemm::GemmUniversalMode::kGemm;
|
||||
|
||||
int batch_count = options.split_k_slices;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{
|
||||
mode,
|
||||
options.problem_size,
|
||||
batch_count,
|
||||
{alpha, beta},
|
||||
tensor_a.device_ref().data(), // <- reference to tensor A on device
|
||||
tensor_b.device_ref().data(), // <- reference to tensor B on device
|
||||
tensor_c.device_ref().data(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref().data(), // <- reference to matrix D on device
|
||||
tensor_reduction.device_ref().data(), // <- reference to the reduction vector on device
|
||||
options.problem_size.m() * options.problem_size.k(),
|
||||
options.problem_size.n() * options.problem_size.k(),
|
||||
options.problem_size.m() * options.problem_size.n(),
|
||||
options.problem_size.m() * options.problem_size.n(),
|
||||
reduce_vector_length,
|
||||
tensor_a.layout().stride(0),
|
||||
tensor_b.layout().stride(0),
|
||||
tensor_c.layout().stride(0),
|
||||
tensor_d.layout().stride(0),
|
||||
tensor_reduction.layout().stride(0)
|
||||
};
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
result.status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
result.status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
result.status = gemm_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
if (options.parallel_split_k && batch_count > 1) {
|
||||
// reduce gemm
|
||||
int splitk_gemm_stride = options.problem_size.m();
|
||||
|
||||
cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride);
|
||||
|
||||
void * workspace_gemm_ptr = workspace.get();
|
||||
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> workspace_gemm_tensorref(static_cast<ElementOutput *>(workspace_gemm_ptr), splitk_gemm_layout);
|
||||
|
||||
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_d_tensorref(tensor_d.device_ref().data(), splitk_gemm_layout);
|
||||
|
||||
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_c_tensorref(tensor_c.device_ref().data(), splitk_gemm_layout);
|
||||
|
||||
typename ReduceGemmSplitK::Arguments reduce_gemm_splitk_arguments{
|
||||
cutlass::MatrixCoord(options.problem_size.n(), options.problem_size.m()),
|
||||
batch_count,
|
||||
size_t(options.problem_size.m() * options.problem_size.n()),
|
||||
workspace_gemm_tensorref,
|
||||
tensor_d_tensorref,
|
||||
tensor_c_tensorref,
|
||||
{alpha, beta}
|
||||
};
|
||||
|
||||
ReduceGemmSplitK reduce_gemm_splitk_op;
|
||||
|
||||
result.status = reduce_gemm_splitk_op.initialize(reduce_gemm_splitk_arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = reduce_gemm_splitk_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
// reduce k vector
|
||||
cutlass::layout::RowMajor splitk_vector_layout(reduce_vector_length);
|
||||
|
||||
ElementOutput *workspace_vector_ptr = static_cast<ElementOutput *>(workspace_gemm_ptr) + batch_count * options.problem_size.m() * options.problem_size.n();
|
||||
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> workspace_vector_tensorref(workspace_vector_ptr, splitk_vector_layout);
|
||||
|
||||
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_reduction_tensorref(tensor_reduction.device_ref().data(), splitk_vector_layout);
|
||||
|
||||
cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor> tensor_nullptr_tensorref(nullptr, splitk_vector_layout);
|
||||
|
||||
typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments{
|
||||
cutlass::MatrixCoord(1, reduce_vector_length),
|
||||
batch_count,
|
||||
size_t(reduce_vector_length),
|
||||
workspace_vector_tensorref,
|
||||
tensor_reduction_tensorref,
|
||||
tensor_nullptr_tensorref,
|
||||
{1.0f, 0.0f}
|
||||
};
|
||||
|
||||
ReduceVectorSplitK reduce_vector_splitk_op;
|
||||
|
||||
result.status = reduce_vector_splitk_op.initialize(reduce_vector_splitk_arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = reduce_vector_splitk_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
//
|
||||
// Create instantiation for device reference gemm kernel
|
||||
//
|
||||
if (options.reference_check) {
|
||||
// Launch device reference gemm kernel to compute alpha * A * B + beta * C
|
||||
cutlass::reference::device::Gemm<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator> gemm_device;
|
||||
|
||||
gemm_device
|
||||
(
|
||||
options.problem_size,
|
||||
alpha,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
beta,
|
||||
tensor_c.device_ref(),
|
||||
tensor_ref_d.device_ref()
|
||||
);
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d.sync_host();
|
||||
tensor_ref_d.sync_host();
|
||||
|
||||
tensor_reduction.sync_host();
|
||||
|
||||
// Compute the reference operand reduction vector (sum over K of A or B) on the host
|
||||
if (ReduceKForA) {
|
||||
for (int m = 0; m < options.problem_size.m(); ++m) {
|
||||
for (int k = 0; k < options.problem_size.k(); ++k) {
|
||||
tensor_ref_reduction.at({m, 0}) +=
|
||||
tensor_a.at(cutlass::MatrixCoord(m, k));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < options.problem_size.k(); ++k) {
|
||||
for (int n = 0; n < options.problem_size.n(); ++n) {
|
||||
tensor_ref_reduction.at({n, 0}) +=
|
||||
tensor_b.at(cutlass::MatrixCoord(k, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(),
|
||||
tensor_reduction.host_view());
|
||||
|
||||
if (!pass) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
} else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
} else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "23_ampere_gemm_operand_reduction_fusion"
|
||||
<< options.problem_size.m() << "x" << options.problem_size.n() << "x" << options.problem_size.k()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "A = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "B = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n";
|
||||
output_workspace << "Reference reduction vector= \n" << tensor_ref_reduction.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
|
||||
output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with the CUDA 11 Toolkit to run this example.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
|
||||
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
struct Benchmark {
|
||||
int m, n, k, split_k_slices, parallel_split_k;
|
||||
} problem_sizes[] = {
|
||||
{4096, 6144, 4096, 1, false},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << "\n";
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &problem_size : problem_sizes) {
|
||||
options.update({problem_size.m, problem_size.n, problem_size.k},
|
||||
problem_size.split_k_slices, problem_size.parallel_split_k);
|
||||
|
||||
Result result = profile(options);
|
||||
result.print(std::cout, idx, options) << "\n";
|
||||
|
||||
++idx;
|
||||
}
|
||||
} else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << "\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile(options);
|
||||
|
||||
Result::print_header(std::cout, options) << "\n";
|
||||
result.print(std::cout, 1, options) << "\n";
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
28
examples/24_gemm_grouped/CMakeLists.txt
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
24_gemm_grouped
|
||||
gemm_grouped.cu
|
||||
)
|
||||
|
||||
1326
examples/24_gemm_grouped/gemm_grouped.cu
Normal file
File diff suppressed because it is too large
28
examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
25_ampere_fprop_mainloop_fusion
|
||||
ampere_fprop_mainloop_fusion.cu
|
||||
)
|
||||
|
||||
751
examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu
Normal file
@ -0,0 +1,751 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
|
||||
This example shows how to fuse the activations' per-channel scale+bias+relu
into the fprop mainloop.

Compared with the original fprop kernel, this example has two more vectors, one for
the scale and one for the bias. The length of the vectors is the same as the
activation channel count. The kernel loads the vectors when the associated
activation channels are loaded in the mainloop. Between reading the
activations and scale/bias data from shared memory and issuing tensor core
instructions, scale+bias+relu is computed in the register file.

This example is customized for the Ampere 16816 fp16 tensor core instruction.
Changing to different data types or a different tensor core instruction requires
source code changes. See
include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more
technical details.

This example is based on 16_ampere_tensorop_conv2dfprop. The command
line is the same.
|
||||
*/
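
// A note on the fused transform (illustrative comment, not part of the kernel): for each
// activation element loaded by the mainloop, the per-channel vectors (tensor_a_scale and
// tensor_a_bias below) are applied as
//
//   a_fused[n][h][w][c] = max(0, a[n][h][w][c] * scale[c] + bias[c])
//
// The host-side reference check later in this file performs exactly this computation
// before invoking the unfused reference convolution.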
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes the data types for input and output tensors and for the
// computation between elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputScaleBias = cutlass::half_t;    // Data type of elements in input scale and bias vectors
|
||||
using ElementOutput = float; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputScaleBias = cutlass::layout::RowMajor;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of stages in the software-pipelined mainloop
|
||||
constexpr int NumStages = 4;
|
||||
|
||||
// This code section describes whether the iterator algorithm selected is Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
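// With ElementOutput = float (32 bits) in this example, the expression above evaluates to
// 128 / 32 = 4 elements per vectorized access.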
|
||||
|
||||
using Conv2dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv2dFpropFusion<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementInputScaleBias, LayoutInputScaleBias,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv2dFpropFusionKernel>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(true),
|
||||
measure_performance(false),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
||||
//
|
||||
int const kAlignment = 8;
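// 128 bits per vector / 16 bits per cutlass::half_t element = 8 elements.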
|
||||
|
||||
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size,
|
||||
cutlass::MatrixCoord stride) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
conv_stride = stride;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "25_ampere_fprop_mainloop_fusion example\n\n"
|
||||
<< " This example fuses scale+bias+relu of the activations into Ampere's\n"
|
||||
<< " Tensor Core operators on F16 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
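
// With the default sizes above (H = W = 32, R = S = 3, padding 1, stride 1), this yields
// P = (32 + 1 + 1 - 3) / 1 + 1 = 32 and likewise Q = 32.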
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
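
// For the default problem size, NPQK = 1 * 32 * 32 * 32 = 32768 and CRS = 32 * 3 * 3 = 288,
// so one convolution performs 2 * 32768 * 288 = 18,874,368 flops.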
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< options.conv_stride.row() << ","
|
||||
<< options.conv_stride.column() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_transformed_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
|
||||
tensor_a_scale({1, options.input_size.c()});
|
||||
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
|
||||
tensor_a_bias({1, options.input_size.c()});
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(3),
|
||||
ElementInputA(-4),
|
||||
0);
|
||||
|
||||
// Fill scale vector for tensor A on host with uniform-distribution random
|
||||
// data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a_scale.host_view(),
|
||||
1,
|
||||
ElementInputA(3),
|
||||
ElementInputA(-4),
|
||||
0);
|
||||
|
||||
// Fill bias vector for tensor A on host with uniform-distribution random
|
||||
// data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a_bias.host_view(),
|
||||
1,
|
||||
ElementInputA(3),
|
||||
ElementInputA(-4),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor C for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_a_scale.sync_device();
|
||||
tensor_a_bias.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_ref_c.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
typename ImplicitGemmFusion::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_a_scale.device_ref(),
|
||||
tensor_a_bias.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemmFusion implicit_gemm_fusion_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_fusion_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_fusion_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on device...\n";
|
||||
|
||||
// Compute scale + bias + relu in host code
|
||||
for (int n = 0; n < options.input_size.n(); ++n) {
|
||||
for (int h = 0; h < options.input_size.h(); ++h) {
|
||||
for (int w = 0; w < options.input_size.w(); ++w) {
|
||||
for (int c = 0; c < options.input_size.c(); ++c) {
|
||||
tensor_transformed_a.at({n, h, w, c}) = std::max(
|
||||
ElementOutput(0), ElementOutput(tensor_a.at({n, h, w, c}) *
|
||||
tensor_a_scale.at({0, c}) +
|
||||
tensor_a_bias.at({0, c})));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor_transformed_a.sync_device();
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::device::Conv2dFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
tensor_transformed_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_ref_c.device_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_c.sync_host();
|
||||
tensor_ref_c.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_c.host_view(),
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "18_ampere_fused_fprop_batch_normalization_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_fusion_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major == 8 && props.minor == 0)) {
|
||||
std::cerr << "This test must run on SM80 A100.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
int batch_sizes[] = {34, 408};
|
||||
|
||||
struct Benchmark {
|
||||
int h, w, c, k, r, s, stride_h, stride_w;
|
||||
} layers[] = {
|
||||
{56, 56, 64, 256, 1, 1, 1, 1},
|
||||
{56, 56, 64, 64, 1, 1, 1, 1},
|
||||
{56, 56, 64, 64, 3, 3, 1, 1},
|
||||
{56, 56, 256, 64, 1, 1, 1, 1},
|
||||
{56, 56, 256, 512, 1, 1, 2, 2},
|
||||
{56, 56, 256, 128, 1, 1, 1, 1},
|
||||
{56, 56, 128, 128, 3, 3, 2, 2},
|
||||
{28, 28, 128, 512, 1, 1, 1, 1},
|
||||
{28, 28, 512, 128, 1, 1, 1, 1},
|
||||
{28, 28, 128, 128, 3, 3, 1, 1},
|
||||
{28, 28, 512, 1024, 1, 1, 2, 2},
|
||||
{28, 28, 512, 256, 1, 1, 1, 1},
|
||||
{28, 28, 256, 256, 3, 3, 2, 2},
|
||||
{14, 14, 256, 1024, 1, 1, 1, 1},
|
||||
{14, 14, 1024, 256, 1, 1, 1, 1},
|
||||
{14, 14, 256, 256, 3, 3, 1, 1},
|
||||
{14, 14, 1024, 2048, 1, 1, 2, 2},
|
||||
{14, 14, 1024, 512, 1, 1, 1, 1},
|
||||
{14, 14, 512, 512, 3, 3, 2, 2},
|
||||
{ 7, 7, 512, 2048, 1, 1, 1, 1},
|
||||
{ 7, 7, 2048, 512, 1, 1, 1, 1},
|
||||
{ 7, 7, 512, 512, 3, 3, 1, 1},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &layer : layers) {
|
||||
for (auto N : batch_sizes) {
|
||||
options.update({N, layer.h, layer.w, layer.c},
|
||||
{layer.k, layer.r, layer.s, layer.c},
|
||||
{layer.stride_h, layer.stride_w});
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
result.print(std::cout, idx, options) << std::endl;
|
||||
}
|
||||
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
28
examples/26_ampere_wgrad_mainloop_fusion/CMakeLists.txt
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
26_ampere_wgrad_mainloop_fusion
|
||||
ampere_wgrad_mainloop_fusion.cu
|
||||
)
|
||||
|
||||
749
examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu
Normal file
@ -0,0 +1,749 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
|
||||
This example shows how to fuse the activations' per-channel scale+bias+relu
into the wgrad mainloop.

Compared with the original wgrad kernel, this example has two more vectors, one for
the scale and one for the bias. The length of the vectors is the same as the
activation channel count. The kernel loads the vectors when the associated
activation channels are loaded in the mainloop. Between reading the
activations and scale/bias data from shared memory and issuing tensor core
instructions, scale+bias+relu is computed in the register file.

This example is customized for the Ampere 16816 fp16 tensor core instruction.
Changing to different data types or a different tensor core instruction requires
source code changes. See
include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h for more
technical details.
|
||||
*/
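
// A note on the fused transform (illustrative comment, not part of the kernel): in wgrad the
// activations are operand B, so the per-channel vectors (tensor_b_scale and tensor_b_bias
// below) are applied analogously as
//
//   b_fused[n][h][w][c] = max(0, b[n][h][w][c] * scale[c] + bias[c])
//
// before the tensor core instructions consume the activation fragments.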
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_wgrad_fusion.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes the data types for input and output tensors and for the
// computation between elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputScaleBias = cutlass::half_t;    // Data type of elements in input scale and bias vectors
|
||||
using ElementOutput = float; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputScaleBias = cutlass::layout::RowMajor;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of stages in the software-pipelined mainloop
|
||||
constexpr int NumStages = 5;
|
||||
|
||||
// This code section describes whether the iterator algorithm selected is Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
using Conv2dWgradFusionKernel = typename cutlass::conv::kernel::DefaultConv2dWgradFusion<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementInputScaleBias, LayoutInputScaleBias,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv2dWgradFusionKernel>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(true),
|
||||
measure_performance(false),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
||||
//
|
||||
int const kAlignment = 8;
|
||||
|
||||
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size,
|
||||
cutlass::MatrixCoord stride) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
conv_stride = stride;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "26_ampere_fused_wgrad_batch_normalization example\n\n"
|
||||
<< " This example fuses scale+bias+relu from batch norm into Ampere's\n"
|
||||
<< " Tensor Core operators on F16 data types to compute\n"
|
||||
<< " backward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< options.conv_stride.row() << ","
|
||||
<< options.conv_stride.column() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.output_size());
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.input_size);
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_transformed_b(options.input_size);
|
||||
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
|
||||
tensor_b_scale({1, options.input_size.c()});
|
||||
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
|
||||
tensor_b_bias({1, options.input_size.c()});
|
||||
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.filter_size);
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(3),
|
||||
ElementInputA(-4),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill scale vector for tensor B on host with uniform-distribution random
|
||||
// data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b_scale.host_view(),
|
||||
1,
|
||||
ElementInputA(3),
|
||||
ElementInputA(-4),
|
||||
0);
|
||||
|
||||
// Fill bias vector for tensor B on host with uniform-distribution random
|
||||
// data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b_bias.host_view(),
|
||||
1,
|
||||
ElementInputA(3),
|
||||
ElementInputA(-4),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor C for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_b_scale.sync_device();
|
||||
tensor_b_bias.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_ref_c.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
typename ImplicitGemmFusion::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_b_scale.device_ref(),
|
||||
tensor_b_bias.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemmFusion implicit_gemm_fusion_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_fusion_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_fusion_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on device...\n";
|
||||
|
||||
// Compute scale + bias + relu in host code
|
||||
for (int n = 0; n < options.input_size.n(); ++n) {
|
||||
for (int h = 0; h < options.input_size.h(); ++h) {
|
||||
for (int w = 0; w < options.input_size.w(); ++w) {
|
||||
for (int c = 0; c < options.input_size.c(); ++c) {
|
||||
tensor_transformed_b.at({n, h, w, c}) = std::max(
|
||||
ElementOutput(0), ElementOutput(tensor_b.at({n, h, w, c}) *
|
||||
tensor_b_scale.at({0, c}) +
|
||||
tensor_b_bias.at({0, c})));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor_transformed_b.sync_device();
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::device::Conv2dWgrad<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_transformed_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_ref_c.device_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_c.sync_host();
|
||||
tensor_ref_c.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_c.host_view(),
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "26_ampere_fused_wgrad_batch_normalization_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_fusion_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with the CUDA 11.0 Toolkit (or later) to run this example.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major == 8 && props.minor == 0)) {
|
||||
std::cerr << "This test must run on SM80 A100.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
int batch_sizes[] = {34, 408};
|
||||
|
||||
struct Benchmark {
|
||||
int h, w, c, k, r, s, stride_h, stride_w;
|
||||
} layers[] = {
|
||||
{56, 56, 64, 256, 1, 1, 1, 1},
|
||||
{56, 56, 64, 64, 1, 1, 1, 1},
|
||||
{56, 56, 64, 64, 3, 3, 1, 1},
|
||||
{56, 56, 256, 64, 1, 1, 1, 1},
|
||||
{56, 56, 256, 512, 1, 1, 2, 2},
|
||||
{56, 56, 256, 128, 1, 1, 1, 1},
|
||||
{56, 56, 128, 128, 3, 3, 2, 2},
|
||||
{28, 28, 128, 512, 1, 1, 1, 1},
|
||||
{28, 28, 512, 128, 1, 1, 1, 1},
|
||||
{28, 28, 128, 128, 3, 3, 1, 1},
|
||||
{28, 28, 512, 1024, 1, 1, 2, 2},
|
||||
{28, 28, 512, 256, 1, 1, 1, 1},
|
||||
{28, 28, 256, 256, 3, 3, 2, 2},
|
||||
{14, 14, 256, 1024, 1, 1, 1, 1},
|
||||
{14, 14, 1024, 256, 1, 1, 1, 1},
|
||||
{14, 14, 256, 256, 3, 3, 1, 1},
|
||||
{14, 14, 1024, 2048, 1, 1, 2, 2},
|
||||
{14, 14, 1024, 512, 1, 1, 1, 1},
|
||||
{14, 14, 512, 512, 3, 3, 2, 2},
|
||||
{ 7, 7, 512, 2048, 1, 1, 1, 1},
|
||||
{ 7, 7, 2048, 512, 1, 1, 1, 1},
|
||||
{ 7, 7, 512, 512, 3, 3, 1, 1},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &layer : layers) {
|
||||
for (auto N : batch_sizes) {
|
||||
options.update({N, layer.h, layer.w, layer.c},
|
||||
{layer.k, layer.r, layer.s, layer.c},
|
||||
{layer.stride_h, layer.stride_w});
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
result.print(std::cout, idx, options) << std::endl;
|
||||
}
|
||||
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,744 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h)
|
||||
data types in tensor cores. One big advantage is that we can load in fp32 data and convert them
|
||||
implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate traditional
|
||||
fp32 data by using NVIDIA Ampere architecture.
|
||||
|
||||
We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated
|
||||
using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h).
|
||||
|
||||
The trick is very simple
|
||||
a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big
|
||||
big = convert_to_tf32(fp32)
|
||||
small = convert_to_tf32(fp32 - big)
|
||||
|
||||
a_small x b_small is discarded because it is too small to matter.
|
||||
|
||||
This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32
|
||||
results (SGEMM using SIMT) and against FP64 results (DGEMM)
|
||||
|
||||
To enable this feature, the only change needed is to replace the default OpMultiplyAdd with
OpMultiplyAddFastF32.
|
||||
|
||||
We now have several different flavors of sgemm in the profiler for Ampere. Here are the differences:
|
||||
|
||||
sgemm // CUDA core SIMT kernel. FP32 in, accumulated in FP32, FP32 out.
|
||||
s1688gemm // Use 3xTF32 to emulate FP32. FP32 in, converted in TF32-big and TF32-small internally,
|
||||
// accumulated in FP32, FP32 out.
|
||||
s1688tf32gemm // Use 1xTF32. FP32 in, converted to one TF32 internally, accumulated in FP32, FP32 out.
|
||||
s1688gemm_tf32 // TF32 in, accumulated in FP32, FP32 out.
|
||||
*/
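//
// Illustration (added; not part of the original example): the 3xTF32 decomposition
// described above, reproduced on the host with plain FP32 arithmetic. TF32 rounding is
// approximated here by truncating the 13 low-order mantissa bits (real Tensor Cores
// round to nearest, but the structure of the product is identical). The helper names
// below (tf32_truncate, mul_3xtf32) are hypothetical and exist only for this sketch.
//

#include <cstdint>
#include <cstring>

// Round an FP32 value to TF32 precision (10 explicit mantissa bits).
inline float tf32_truncate(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;                 // clear the 13 low-order mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

// Emulate one FP32 multiply with three TF32 multiplies accumulated in FP32.
inline float mul_3xtf32(float a, float b) {
  float a_big   = tf32_truncate(a);
  float a_small = tf32_truncate(a - a_big);
  float b_big   = tf32_truncate(b);
  float b_small = tf32_truncate(b - b_big);

  // 1xTF32 would keep only a_big * b_big; the two cross terms recover most of the
  // precision lost in the conversion. The a_small * b_small term is dropped.
  return a_big * b_big + a_big * b_small + a_small * b_big;
}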
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_reduce.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
|
||||
int m, n, k;
|
||||
double l2_norm_3xtf32_vs_fp64;
|
||||
double l2_norm_1xtf32_vs_fp64;
|
||||
double l2_norm_fp32_vs_fp64;
|
||||
|
||||
// ctor
|
||||
Result(
|
||||
int m, int n, int k,
|
||||
double runtime_ms, double gflops,
|
||||
double l2_norm_3xtf32_vs_fp64,
|
||||
double l2_norm_1xtf32_vs_fp64,
|
||||
double l2_norm_fp32_vs_fp64) :
|
||||
m(m), n(n), k(k),
|
||||
runtime_ms(runtime_ms), gflops(gflops),
|
||||
l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64),
|
||||
l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64),
|
||||
l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {}
|
||||
|
||||
Result() {}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
static void print_csv_header() {
|
||||
std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl;
|
||||
}
|
||||
|
||||
void print_csv_row() {
|
||||
std::cout << m << ","
|
||||
<< n << ","
|
||||
<< k << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops << ","
|
||||
<< l2_norm_3xtf32_vs_fp64 << ","
|
||||
<< l2_norm_1xtf32_vs_fp64 << ","
|
||||
<< l2_norm_fp32_vs_fp64 << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Result> results;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
float alpha;
|
||||
float beta;
|
||||
std::string rand_mode;
|
||||
|
||||
int iterations;
|
||||
int seed;
|
||||
bool benchmark;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({3456, 4096, 4096}),
|
||||
iterations(20),
|
||||
seed(1),
|
||||
alpha(1),
|
||||
beta(),
|
||||
rand_mode("uniform"),
|
||||
benchmark(false) { }
|
||||
|
||||
bool valid() {
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of F32 elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 4 elements.
|
||||
//
|
||||
int const kAlignment = 4;
|
||||
|
||||
if ((problem_size.m() % kAlignment) ||
|
||||
(problem_size.n() % kAlignment) ||
|
||||
(problem_size.k() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
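// For instance (added illustration, not from the original source): m=1030 would be
// rejected here because 1030 % 4 != 0, while the default 3456 x 4096 x 4096 problem
// satisfies the 4-element (128b) alignment requirement.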
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("seed", seed);
|
||||
cmd.get_cmd_line_argument("rand_mode", rand_mode);
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "27_ampere_3xtf32_fast_accurate_tensorop_gemm example\n\n"
|
||||
<< " This example uses the CUTLASS Library to emulate FP32 with TF32 tensorop GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha\n"
|
||||
<< " --beta <f32> Epilogue scalar beta\n\n"
|
||||
<< " --rand_mode <string> gauss / uniform*\n\n"
|
||||
<< " --seed <int> Random number seed (1*)\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm --m=1024 --n=512 \\\n"
|
||||
<< " --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = problem_size.product();
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
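// Worked example (added for clarity, not from the original source): for the default
// problem size M=3456, N=4096, K=4096,
//   fmas  = 3456 * 4096 * 4096 ~= 5.80e10
//   flops = 2 * fmas           ~= 1.16e11
// so a measured runtime of 10 ms corresponds to roughly 11,600 GFLOP/s.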
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 64, 16>; // <- threadblock tile M = 128, N = 64, K = 16
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = 64, N = 32, K = 16
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- default identity threadblock swizzle
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
float, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<float>::value, // <- the number of elements per vectorized
// memory access. For float this is 4
// elements per 128b access. This becomes
// the vector width of math instructions
// in the epilogue too
|
||||
float, // <- data type of accumulator
|
||||
float>; // <- data type for alpha/beta in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
// Alignment
|
||||
constexpr int Alignment = 4;
|
||||
|
||||
//
|
||||
// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64)
|
||||
//
|
||||
|
||||
// Gemm_3xTF32
|
||||
using Gemm_3xTF32 = cutlass::gemm::device::Gemm<
|
||||
float,
|
||||
LayoutInputA,
|
||||
float,
|
||||
LayoutInputB,
|
||||
float,
|
||||
LayoutOutput,
|
||||
float,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
Alignment,
|
||||
Alignment,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddFastF32>;
|
||||
|
||||
// Gemm_1xTF32
|
||||
using Gemm_1xTF32 = cutlass::gemm::device::Gemm<
|
||||
float,
|
||||
LayoutInputA,
|
||||
float,
|
||||
LayoutInputB,
|
||||
float,
|
||||
LayoutOutput,
|
||||
float,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
Alignment,
|
||||
Alignment,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAdd>;
|
||||
|
||||
// Gemm_F64
|
||||
using Gemm_F64 = cutlass::reference::device::Gemm<
|
||||
double,
|
||||
LayoutInputA,
|
||||
double,
|
||||
LayoutInputB,
|
||||
double,
|
||||
LayoutOutput,
|
||||
double,
|
||||
double>;
|
||||
|
||||
// Gemm_F32
|
||||
using Gemm_F32 = cutlass::reference::device::Gemm<
|
||||
float,
|
||||
LayoutInputA,
|
||||
float,
|
||||
LayoutInputB,
|
||||
float,
|
||||
LayoutOutput,
|
||||
float,
|
||||
float>;
|
||||
|
||||
bool run(Options &options) {
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size = options.problem_size;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 1. Initialize F32 Precision input tensors using CUTLASS helper functions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
cutlass::HostTensor<float, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<float, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
if (options.rand_mode == "uniform") {
|
||||
const float min = -1;
|
||||
const float max = 1;
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a_F32.host_view(),
|
||||
options.seed,
|
||||
double(max),
|
||||
double(min)); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b_F32.host_view(),
|
||||
options.seed,
|
||||
double(max),
|
||||
double(min)); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c_F32.host_view(),
|
||||
options.seed,
|
||||
double(max),
|
||||
double(min)); // <- Fill matrix C on host with uniform-distribution random data
|
||||
} else if (options.rand_mode == "gauss") {
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomGaussian(
|
||||
tensor_a_F32.host_view(),
|
||||
options.seed,
|
||||
double(0),
|
||||
double(5)); // <- Fill matrix A on host with gaussian-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomGaussian(
|
||||
tensor_b_F32.host_view(),
|
||||
options.seed,
|
||||
double(0),
|
||||
double(5)); // <- Fill matrix B on host with gaussian-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomGaussian(
|
||||
tensor_c_F32.host_view(),
|
||||
options.seed,
|
||||
double(0),
|
||||
double(5)); // <- Fill matrix C on host with gaussian-distribution random data
|
||||
}
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_F32.host_view()); // <- fill matrix D on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F32.sync_device();
|
||||
tensor_b_F32.sync_device();
|
||||
tensor_c_F32.sync_device();
|
||||
tensor_d_F32.sync_device();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 2. Initialize F64 tensors using the same values used for F32
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Gemm input operands (A, B, C)
|
||||
cutlass::HostTensor<double, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<double, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
|
||||
// Gemm output (D) for GEMM_F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Gemm output (D) for GEMM_3xTF32
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Gemm output (D) for GEMM_1xTF32
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
// Copy values from the F32 tensors into the F64 (and output) tensors
|
||||
cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F64.sync_device();
|
||||
tensor_b_F64.sync_device();
|
||||
tensor_c_F64.sync_device();
|
||||
tensor_d_F64.sync_device();
|
||||
tensor_d_3xTF32.sync_device();
|
||||
tensor_d_1xTF32.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
float alpha = float(options.alpha);
|
||||
float beta = float(options.beta);
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 3. Run 3xTF32 kernel within a profiling loop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm_3xTF32::Arguments arguments_3xtf32{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a_F32.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b_F32.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c_F32.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d_3xTF32.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size_3xtf32 = Gemm_3xTF32::get_workspace_size(arguments_3xtf32);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_3xtf32(workspace_size_3xtf32);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_3xTF32 gemm_op_3xTF32;
|
||||
|
||||
// Check whether the problem size is supported or not
|
||||
cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32);
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status_3xtf32 = gemm_op_3xTF32.initialize(arguments_3xtf32, workspace_3xtf32.get());
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
// Result structure
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Run profiling loop
|
||||
//
|
||||
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// Launch initialized CUTLASS kernel
|
||||
status_3xtf32 = gemm_op_3xTF32();
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
result.m = problem_size.m();
|
||||
result.n = problem_size.n();
|
||||
result.k = problem_size.k();
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
tensor_d_3xTF32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 4. Run TF32 kernel without profiling loop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm_1xTF32::Arguments arguments_1xtf32{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a_F32.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b_F32.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c_F32.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d_1xTF32.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size_1xtf32 = Gemm_1xTF32::get_workspace_size(arguments_1xtf32);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_1xtf32(workspace_size_1xtf32);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_1xTF32 gemm_op_1xtf32;
|
||||
|
||||
// Check whether the problem size is supported or not
|
||||
cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32);
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status_1xtf32 = gemm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get());
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status_1xtf32 = gemm_op_1xtf32();
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
tensor_d_1xTF32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Run reference kernel (F64)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
Gemm_F64 gemm_f64;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_f64(problem_size,
|
||||
alpha,
|
||||
tensor_a_F64.device_ref(),
|
||||
tensor_b_F64.device_ref(),
|
||||
beta,
|
||||
tensor_c_F64.device_ref(),
|
||||
tensor_d_F64.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d_F64.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Run reference kernel (F32)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
Gemm_F32 gemm_f32;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_f32(problem_size,
|
||||
alpha,
|
||||
tensor_a_F32.device_ref(),
|
||||
tensor_b_F32.device_ref(),
|
||||
beta,
|
||||
tensor_c_F32.device_ref(),
|
||||
tensor_d_F32.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d_F32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////// Compute l2 norms
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// l2 norm 3xTF32 vs F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_3xTF32_in_F64(problem_size.mn());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view());
|
||||
|
||||
result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
// l2 norm 1xTF32 vs F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_1xTF32_in_F64(problem_size.mn());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view());
|
||||
|
||||
result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
// l2 norm F32 vs F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_F32_in_F64(problem_size.mn());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view());
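// Interpretation note (added; not taken verbatim from the original source): the metric
// above is used as a normalized error, conceptually
//   ||computed - reference||_2 / ||reference||_2,
// so smaller is better. 3xTF32 is expected to land close to the FP32 SIMT result,
// while 1xTF32 typically shows a noticeably larger deviation from FP64.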
|
||||
|
||||
results.push_back(result);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Report runtime, GFLOPs, and the normalized L2 error metrics computed above
|
||||
|
||||
std::cout << std::fixed;
|
||||
std::cout.precision(4);
|
||||
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
|
||||
std::cout.precision(2);
|
||||
std::cout << "GFLOPs: " << result.gflops << std::endl;
|
||||
std::cout << "Normalized L2 norm of" << std::endl;
|
||||
std::cout.precision(8);
|
||||
std::cout << std::scientific
|
||||
<< " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
|
||||
<< " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl
|
||||
<< " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool result = true;
|
||||
|
||||
if (options.benchmark) {
|
||||
for (int k = 4; k <= 65536; k *= 2) {
|
||||
|
||||
options.problem_size[2] = k;
|
||||
|
||||
printf("Gemm problem size: %d x %d x %d\n", \
|
||||
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
result &= run(options);
|
||||
}
|
||||
} else {
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
result = run(options);
|
||||
}
|
||||
|
||||
if (!result) return -1;
|
||||
|
||||
std::cout << std::endl << "CSV results" << std::endl;
|
||||
Result::print_csv_header();
|
||||
for(auto &r : results)
|
||||
r.print_csv_row();
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
27_ampere_3xtf32_fast_accurate_tensorop_gemm
|
||||
27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
28_ampere_3xtf32_fast_accurate_tensorop_fprop
|
||||
ampere_3xtf32_fast_accurate_tensorop_fprop.cu
|
||||
)
|
||||
@ -0,0 +1,815 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
|
||||
This example adapts example 16 to use 3xTF32 to bring FP32 accuracy with 2x performance
|
||||
compared with CUDA Cores. See example 27 for the trick of 3xTF32.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = float; // Data type of elements in input tensor
|
||||
using ElementInputB = float; // Data type of elements in input tensor
|
||||
using ElementOutput = float; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
// This code section describes whether the selected iterator algorithm is Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
// 3xTF32 Fprop
|
||||
using Conv2dFpropKernel_3xTF32 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
// Only thing needs to be changed from normal Fprop
|
||||
cutlass::arch::OpMultiplyAddFastF32,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
// 1xTF32 Fprop
|
||||
using Conv2dFpropKernel_1xTF32 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemm_3xTF32 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel_3xTF32>;
|
||||
using ImplicitGemm_1xTF32 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel_1xTF32>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of float (F32) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 4 elements.
|
||||
//
|
||||
int const kAlignment = 4;
|
||||
|
||||
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
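// For example (added illustration, not from the original source): a 3x3 filter is only
// accepted with padding {1, 1, 1, 1} and a 1x1 filter with zero padding, since the
// checks above require pad_h == R/2 and pad_w == S/2; likewise C and K must be
// multiples of 4 to satisfy the 128b (4 x F32) alignment requirement.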
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
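
  // Note that the parsing logic above forces any filter size other than 3x3 down to a
  // 1x1 filter with zero padding.
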
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "28_ampere_3xtf32_fast_accurate_tensorop_fprop example\n\n"
|
||||
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/28_ampere_3xtf32_fast_accurate_tensorop_fprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/28_ampere_3xtf32_fast_accurate_tensorop_fprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
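
  // Rough illustration with the default Options above (padding 1, unit stride):
  // the 1x32x32x32 input and 32x3x3x32 filter give a 1x32x32x32 output, so
  //   fmas = (1*32*32*32) * (3*3*32) = 32768 * 288 = 9,437,184
  // and a 1 ms runtime would report 2 * 9,437,184 / 1e9 / 1e-3 ~= 18.9 GFLOP/s.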
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
|
||||
double l2_norm_3xtf32_vs_fp64;
|
||||
double l2_norm_1xtf32_vs_fp64;
|
||||
double l2_norm_fp32_vs_fp64;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
error(cudaSuccess),
|
||||
l2_norm_3xtf32_vs_fp64(0),
|
||||
l2_norm_1xtf32_vs_fp64(0),
|
||||
l2_norm_fp32_vs_fp64(0) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops << ","
|
||||
<< l2_norm_3xtf32_vs_fp64 << ","
|
||||
<< l2_norm_1xtf32_vs_fp64 << ","
|
||||
<< l2_norm_fp32_vs_fp64;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 1. Initialize F32 Precision input tensors using CUTLASS helper functions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<float, LayoutInputA> tensor_a_F32(options.input_size);
|
||||
cutlass::HostTensor<float, LayoutInputB> tensor_b_F32(options.filter_size);
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_c_F32(options.output_size());
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a_F32.host_view(),
|
||||
1,
|
||||
ElementInputA(7),
|
||||
ElementInputA(-8));
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b_F32.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8));
|
||||
|
||||
// Fill tensor C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c_F32.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8));
|
||||
|
||||
// Fill tensor D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_F32.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F32.sync_device();
|
||||
tensor_b_F32.sync_device();
|
||||
tensor_c_F32.sync_device();
|
||||
tensor_d_F32.sync_device();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
  /// 2. Initialize F64 tensors and additional output tensors using the same values used for F32
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<double, LayoutInputA> tensor_a_F64(options.input_size);
|
||||
cutlass::HostTensor<double, LayoutInputB> tensor_b_F64(options.filter_size);
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_c_F64(options.output_size());
|
||||
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_F64(options.output_size());
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_3xTF32(options.output_size());
|
||||
cutlass::HostTensor<float, LayoutOutput> tensor_d_1xTF32(options.output_size());
|
||||
|
||||
// Copy values from the DP tensors
|
||||
cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F64.sync_device();
|
||||
tensor_b_F64.sync_device();
|
||||
tensor_c_F64.sync_device();
|
||||
tensor_d_F64.sync_device();
|
||||
tensor_d_3xTF32.sync_device();
|
||||
tensor_d_1xTF32.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
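
  // (Note) split_k_slices controls how many partitions the reduction dimension is split
  // into; with 1, the entire reduction runs in a single partition and no serial split-K
  // reduction step should be required.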
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 3. Run 3xTF32 kernel within a profiling loop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Construct ImplicitGemm::Argument structure with conv2d
|
||||
// problem size, data pointers, and epilogue values
|
||||
typename ImplicitGemm_3xTF32::Arguments arguments_3xTF32{
|
||||
problem_size,
|
||||
tensor_a_F32.device_ref(),
|
||||
tensor_b_F32.device_ref(),
|
||||
tensor_c_F32.device_ref(),
|
||||
tensor_d_3xTF32.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemm_3xTF32 implicit_gemm_op_3xTF32;
|
||||
|
||||
size_t workspace_size_3xTF32 = implicit_gemm_op_3xTF32.get_workspace_size(arguments_3xTF32);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_3xTF32(workspace_size_3xTF32);
|
||||
|
||||
result.status = implicit_gemm_op_3xTF32.can_implement(arguments_3xTF32);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_op_3xTF32.initialize(arguments_3xTF32, workspace_3xTF32.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op_3xTF32();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op_3xTF32();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
tensor_d_3xTF32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 4. Run 1xTF32 kernel within a profiling loop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Construct ImplicitGemm::Argument structure with conv2d
|
||||
// problem size, data pointers, and epilogue values
|
||||
typename ImplicitGemm_1xTF32::Arguments arguments_1xTF32{
|
||||
problem_size,
|
||||
tensor_a_F32.device_ref(),
|
||||
tensor_b_F32.device_ref(),
|
||||
tensor_c_F32.device_ref(),
|
||||
tensor_d_1xTF32.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemm_1xTF32 implicit_gemm_op_1xTF32;
|
||||
|
||||
size_t workspace_size_1xTF32 = implicit_gemm_op_1xTF32.get_workspace_size(arguments_1xTF32);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_1xTF32(workspace_size_1xTF32);
|
||||
|
||||
result.status = implicit_gemm_op_1xTF32.can_implement(arguments_1xTF32);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_op_1xTF32.initialize(arguments_1xTF32, workspace_1xTF32.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op_1xTF32();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
tensor_d_1xTF32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Run reference kernel (F64)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
double,
|
||||
LayoutInputA,
|
||||
double,
|
||||
LayoutInputB,
|
||||
double,
|
||||
LayoutOutput,
|
||||
double,
|
||||
double
|
||||
>(
|
||||
cutlass::conv::Operator::kFprop,
|
||||
problem_size,
|
||||
tensor_a_F64.device_ref(),
|
||||
tensor_b_F64.device_ref(),
|
||||
tensor_c_F64.device_ref(),
|
||||
tensor_d_F64.device_ref(),
|
||||
options.alpha,
|
||||
options.beta);
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d_F64.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Run reference kernel (F32)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
float,
|
||||
LayoutInputA,
|
||||
float,
|
||||
LayoutInputB,
|
||||
float,
|
||||
LayoutOutput,
|
||||
float,
|
||||
float
|
||||
>(
|
||||
cutlass::conv::Operator::kFprop,
|
||||
problem_size,
|
||||
tensor_a_F32.device_ref(),
|
||||
tensor_b_F32.device_ref(),
|
||||
tensor_c_F32.device_ref(),
|
||||
tensor_d_F32.device_ref(),
|
||||
options.alpha,
|
||||
options.beta);
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d_F32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////// Compute l2 norms
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// l2 norm 3xTF32 vs F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_3xTF32_in_F64(options.output_size());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view());
|
||||
|
||||
result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
// l2 norm 1xTF32 vs F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_1xTF32_in_F64(options.output_size());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view());
|
||||
|
||||
result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
// l2 norm F32 vs F64
|
||||
cutlass::HostTensor<double, LayoutOutput> tensor_d_F32_in_F64(options.output_size());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view());
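
  // (Note) TensorRelativeErrorMetric is used here as a normalized error measure, roughly
  // of the form ||x - x_ref|| / ||x_ref|| over the host views; smaller values mean closer
  // agreement with the FP64 reference.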
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "28_ampere_3xtf32_fast_accurate_tensorop_fprop_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a_F32.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b_F32.host_view() << "\n\n";
|
||||
|
||||
output_workspace << "TF32x3 = \n" << tensor_d_3xTF32.host_view() << std::endl;
|
||||
output_workspace << "TF32x1 = \n" << tensor_d_1xTF32.host_view() << std::endl;
|
||||
output_workspace << "FP32 = \n" << tensor_d_F32.host_view() << std::endl;
|
||||
output_workspace << "FP64 = \n" << tensor_d_F64.host_view() << "\n\n";
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
|
||||
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
int batch_sizes[] = {1, 32, 64, 128, 256};
|
||||
|
||||
struct Benchmark {
|
||||
int h, w, c, k, r, s;
|
||||
} layers[] = {
|
||||
{56, 56, 64, 256, 1, 1},
|
||||
{56, 56, 64, 64, 1, 1},
|
||||
{56, 56, 64, 64, 3, 3},
|
||||
{56, 56, 256, 64, 1, 1},
|
||||
{56, 56, 256, 512, 1, 1},
|
||||
{56, 56, 256, 128, 1, 1},
|
||||
{28, 28, 128, 128, 3, 3},
|
||||
{28, 28, 128, 512, 1, 1},
|
||||
{28, 28, 512, 128, 1, 1},
|
||||
{28, 28, 512, 1024, 1, 1},
|
||||
{28, 28, 512, 256, 1, 1},
|
||||
{14, 14, 256, 256, 3, 3},
|
||||
{14, 14, 256, 1024, 1, 1},
|
||||
{14, 14, 1024, 256, 1, 1},
|
||||
{14, 14, 1024, 2048, 1, 1},
|
||||
{14, 14, 1024, 512, 1, 1},
|
||||
{7, 7, 512, 512, 3, 3},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &layer : layers) {
|
||||
for (auto N : batch_sizes) {
|
||||
|
||||
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
result.print(std::cout, idx, options) << std::endl;
|
||||
}
|
||||
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,686 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
  This example is almost the same as example 27, which uses 3xTF32 to run a real-valued GEMM.
  The only difference is that this example applies 3xTF32 to complex GEMM.

  To enable this feature, the only change needed is to replace OpMultiplyAddComplex
  with OpMultiplyAddComplexFastF32.
|
||||
*/
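
//
// Minimal sketch of that difference (for orientation; the full definitions appear below):
//
//   using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex<
//       ..., cutlass::arch::OpMultiplyAddComplexFastF32>;   // emulated F32 (3xTF32)
//
//   using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex<
//       ..., cutlass::arch::OpMultiplyAddComplex>;          // plain TF32 (1xTF32)
//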
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_complex.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/host/tensor_reduce.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
|
||||
int m, n, k;
|
||||
double l2_norm_3xtf32_vs_fp64;
|
||||
double l2_norm_1xtf32_vs_fp64;
|
||||
double l2_norm_fp32_vs_fp64;
|
||||
|
||||
// ctor
|
||||
Result(
|
||||
int m, int n, int k,
|
||||
double runtime_ms, double gflops,
|
||||
double l2_norm_3xtf32_vs_fp64,
|
||||
double l2_norm_1xtf32_vs_fp64,
|
||||
double l2_norm_fp32_vs_fp64) :
|
||||
m(m), n(n), k(k),
|
||||
runtime_ms(runtime_ms), gflops(gflops),
|
||||
l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64),
|
||||
l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64),
|
||||
l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {}
|
||||
|
||||
Result() {}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
static void print_csv_header() {
|
||||
std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl;
|
||||
}
|
||||
|
||||
void print_csv_row() {
|
||||
std::cout << m << ","
|
||||
<< n << ","
|
||||
<< k << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops << ","
|
||||
<< l2_norm_3xtf32_vs_fp64 << ","
|
||||
<< l2_norm_1xtf32_vs_fp64 << ","
|
||||
<< l2_norm_fp32_vs_fp64 << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Result> results;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
float alpha;
|
||||
float beta;
|
||||
std::string rand_mode;
|
||||
|
||||
int iterations;
|
||||
int seed;
|
||||
bool benchmark;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({3456, 4096, 4096}),
|
||||
iterations(20),
|
||||
seed(1),
|
||||
alpha(1),
|
||||
beta(),
|
||||
rand_mode("uniform"),
|
||||
benchmark(false) { }
|
||||
|
||||
bool valid() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("seed", seed);
|
||||
cmd.get_cmd_line_argument("rand_mode", rand_mode);
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm example\n\n"
|
||||
<< " This example uses the CUTLASS Library to emulate FP32 complex GEMM computations with TF32 tensor cores.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha\n"
|
||||
<< " --beta <f32> Epilogue scalar beta\n\n"
|
||||
<< " --rand_mode <string> gauss / uniform*\n\n"
|
||||
<< " --seed <int> Random number seed (1*)\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n"
|
||||
<< " --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
    // Number of complex-valued multiply-adds
|
||||
int64_t fmas = problem_size.product();
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Column Major for
|
||||
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::ColumnMajor;
|
||||
using LayoutInputB = cutlass::layout::RowMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
    cutlass::gemm::GemmShape<64, 64, 16>;  // <- threadblock tile M = 64, N = 64, K = 16
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>;  // <- warp tile M = 32, N = 32, K = 16
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- default identity threadblock swizzling function
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
cutlass::complex<float>, // <- data type of output matrix
|
||||
    1,                                       // <- number of elements per vectorized
                                             //    memory access (1 for complex<float> here);
                                             //    this also becomes the vector width of
                                             //    math instructions in the epilogue
|
||||
cutlass::complex<float>, // <- data type of accumulator
|
||||
cutlass::complex<float>>; // <- data type for alpha/beta in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
// Transform
|
||||
constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone;
|
||||
constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone;
|
||||
|
||||
//
|
||||
// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64)
|
||||
//
|
||||
|
||||
// Gemm_3xTF32
|
||||
using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex<
|
||||
cutlass::complex<float>,
|
||||
LayoutInputA,
|
||||
cutlass::complex<float>,
|
||||
LayoutInputB,
|
||||
cutlass::complex<float>,
|
||||
LayoutOutput,
|
||||
cutlass::complex<float>,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
TransformA,
|
||||
TransformB,
|
||||
cutlass::arch::OpMultiplyAddComplexFastF32>;
|
||||
|
||||
// Gemm_1xTF32
|
||||
using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex<
|
||||
cutlass::complex<float>,
|
||||
LayoutInputA,
|
||||
cutlass::complex<float>,
|
||||
LayoutInputB,
|
||||
cutlass::complex<float>,
|
||||
LayoutOutput,
|
||||
cutlass::complex<float>,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
TransformA,
|
||||
TransformB,
|
||||
cutlass::arch::OpMultiplyAddComplex>;
|
||||
|
||||
bool run(Options &options) {
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size = options.problem_size;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 1. Initialize F32 Precision input tensors using CUTLASS helper functions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
if (options.rand_mode == "uniform") {
|
||||
const float min = -1;
|
||||
const float max = 1;
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a_F32.host_view(),
|
||||
options.seed,
|
||||
double(max),
|
||||
double(min)); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b_F32.host_view(),
|
||||
options.seed,
|
||||
double(max),
|
||||
double(min)); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c_F32.host_view(),
|
||||
options.seed,
|
||||
double(max),
|
||||
double(min)); // <- Fill matrix C on host with uniform-distribution random data
|
||||
} else if (options.rand_mode == "gauss") {
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomGaussian(
|
||||
tensor_a_F32.host_view(),
|
||||
options.seed,
|
||||
double(0),
|
||||
double(5)); // <- Fill matrix A on host with gaussian-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomGaussian(
|
||||
tensor_b_F32.host_view(),
|
||||
options.seed,
|
||||
double(0),
|
||||
double(5)); // <- Fill matrix B on host with gaussian-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomGaussian(
|
||||
tensor_c_F32.host_view(),
|
||||
options.seed,
|
||||
double(0),
|
||||
double(5)); // <- Fill matrix C on host with gaussian-distribution random data
|
||||
}
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_F32.host_view()); // <- fill matrix D on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F32.sync_device();
|
||||
tensor_b_F32.sync_device();
|
||||
tensor_c_F32.sync_device();
|
||||
tensor_d_F32.sync_device();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 2. Initialize F64 tensors using the same values used for F32
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Gemm input operands (A, B, C)
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
|
||||
// Gemm output (D) for GEMM_F64
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Gemm output (D) for GEMM_3xTF32
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
// Gemm output (D) for GEMM_1xTF32
|
||||
cutlass::HostTensor<cutlass::complex<float>, LayoutOutput> tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N
|
||||
|
||||
// Copy values from the DP tensors
|
||||
cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a_F64.sync_device();
|
||||
tensor_b_F64.sync_device();
|
||||
tensor_c_F64.sync_device();
|
||||
tensor_d_F64.sync_device();
|
||||
tensor_d_3xTF32.sync_device();
|
||||
tensor_d_1xTF32.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
cutlass::complex<float> alpha = cutlass::complex<float>(options.alpha);
|
||||
cutlass::complex<float> beta = cutlass::complex<float>(options.beta);
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 3. Run 3xTF32 kernel within a profiling loop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm_3xTF32::Arguments arguments_3xtf32{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a_F32.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b_F32.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c_F32.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d_3xTF32.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size_3xtf32 = Gemm_3xTF32::get_workspace_size(arguments_3xtf32);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_3xtf32(workspace_size_3xtf32);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_3xTF32 gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32);
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status_3xtf32 = gemm_op.initialize(arguments_3xtf32, workspace_3xtf32.get());
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
|
||||
// Result structure
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// Run profiling loop
|
||||
//
|
||||
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// Launch initialized CUTLASS kernel
|
||||
status_3xtf32 = gemm_op();
|
||||
CUTLASS_CHECK(status_3xtf32);
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
result.m = problem_size.m();
|
||||
result.n = problem_size.n();
|
||||
result.k = problem_size.k();
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
tensor_d_3xTF32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// 4. Run TF32 kernel without profiling loop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm_1xTF32::Arguments arguments_1xtf32{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a_F32.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b_F32.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c_F32.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d_1xTF32.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size_1xtf32 = Gemm_1xTF32::get_workspace_size(arguments_1xtf32);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_1xtf32(workspace_size_1xtf32);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm_1xTF32 gemm_op_1xtf32;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32);
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status_1xtf32 = gemm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get());
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status_1xtf32 = gemm_op_1xtf32();
|
||||
CUTLASS_CHECK(status_1xtf32);
|
||||
|
||||
tensor_d_1xTF32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Run reference kernel (F64)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
cutlass::reference::device::GemmComplex(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a_F64.device_ref(),
|
||||
TransformA,
|
||||
tensor_b_F64.device_ref(),
|
||||
TransformB,
|
||||
beta,
|
||||
tensor_c_F64.device_ref(),
|
||||
tensor_d_F64.device_ref(),
|
||||
cutlass::complex<double>(0.f));
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d_F64.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Run reference kernel (F32)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
cutlass::reference::device::GemmComplex(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a_F32.device_ref(),
|
||||
TransformA,
|
||||
tensor_b_F32.device_ref(),
|
||||
TransformB,
|
||||
beta,
|
||||
tensor_c_F32.device_ref(),
|
||||
tensor_d_F32.device_ref(),
|
||||
cutlass::complex<float>(0.f));
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d_F32.sync_host();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////// Compute l2 norms
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// l2 norm 3xTF32 vs F64
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_d_3xTF32_in_F64(problem_size.mn());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view());
|
||||
|
||||
result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
// l2 norm 1xTF32 vs F64
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_d_1xTF32_in_F64(problem_size.mn());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view());
|
||||
|
||||
result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
// l2 norm F32 vs F64
|
||||
cutlass::HostTensor<cutlass::complex<double>, LayoutOutput> tensor_d_F32_in_F64(problem_size.mn());
|
||||
cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view());
|
||||
|
||||
result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view());
|
||||
|
||||
results.push_back(result);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
|
||||
std::cout << std::fixed;
|
||||
std::cout.precision(4);
|
||||
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
|
||||
std::cout.precision(2);
|
||||
std::cout << "GFLOPs: " << result.gflops << std::endl;
|
||||
std::cout << "Normalized L2 norm of" << std::endl;
|
||||
std::cout.precision(8);
|
||||
std::cout << std::scientific
|
||||
<< " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
|
||||
<< " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl
|
||||
<< " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool result = true;
|
||||
|
||||
if (options.benchmark) {
|
||||
for (int k = 4; k <= 65536; k *= 2) {
|
||||
|
||||
options.problem_size[2] = k;
|
||||
|
||||
printf("Gemm problem size: %d x %d x %d\n", \
|
||||
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
result &= run(options);
|
||||
}
|
||||
} else {
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
result = run(options);
|
||||
}
|
||||
|
||||
if (!result) return -1;
|
||||
|
||||
std::cout << std::endl << "CSV results" << std::endl;
|
||||
Result::print_csv_header();
|
||||
for(auto &r : results)
|
||||
r.print_csv_row();
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm
|
||||
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu
|
||||
)
|
||||
|
||||
@ -93,6 +93,13 @@ foreach(EXAMPLE
|
||||
20_simt_canonical
|
||||
21_quaternion_gemm
|
||||
22_quaternion_conv
|
||||
23_ampere_gemm_operand_reduction_fusion
|
||||
24_gemm_grouped
|
||||
25_ampere_fprop_mainloop_fusion
|
||||
26_ampere_wgrad_mainloop_fusion
|
||||
27_ampere_3xtf32_fast_accurate_tensorop_gemm
|
||||
28_ampere_3xtf32_fast_accurate_tensorop_fprop
|
||||
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
@ -225,6 +225,34 @@ struct global_store;
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 64> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint4 const *data = reinterpret_cast<uint4 const *>(&D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
|
||||
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
|
||||
" @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n"
|
||||
" @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
|
||||
"r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
|
||||
"r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w),
|
||||
"l"(((uint8_t *)ptr) + 32),
|
||||
"r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w),
|
||||
"l"(((uint8_t *)ptr) + 48),
|
||||
"r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[2].w));
|
||||
}
|
||||
};
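
// (Note) The 64-byte specialization above issues four predicated 16-byte
// st.global.v4.u32 stores at byte offsets 0, 16, 32, and 48 from `ptr`, all guarded by
// the same predicate.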
|
||||
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 32> {
|
||||
CUTLASS_DEVICE
|
||||
|
||||
@ -282,5 +282,52 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccessType, int Bytes>
|
||||
struct shared_load_op {
|
||||
CUTLASS_DEVICE
|
||||
shared_load_op(AccessType &D, void const *ptr) {
|
||||
D = *reinterpret_cast<AccessType const *>(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType>
|
||||
CUTLASS_DEVICE void shared_load(AccessType &D, void const *ptr) {
|
||||
shared_load_op<AccessType, int(sizeof(AccessType))>(D, ptr);
|
||||
}
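
// (Note) shared_load dispatches on sizeof(AccessType): the 16-byte and 8-byte
// specializations below use explicit ld.shared vector instructions, while other sizes
// fall back to the plain dereference in the generic template above.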
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccessType>
|
||||
struct shared_load_op<AccessType, 16> {
|
||||
CUTLASS_DEVICE
|
||||
shared_load_op(AccessType &D, void const *ptr) {
|
||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||
|
||||
uint4 v;
|
||||
asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" :
|
||||
"=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w) : "r"(addr));
|
||||
|
||||
D = reinterpret_cast<AccessType const &>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccessType>
|
||||
struct shared_load_op<AccessType, 8> {
|
||||
CUTLASS_DEVICE
|
||||
shared_load_op(AccessType &D, void const *ptr) {
|
||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||
|
||||
uint2 v;
|
||||
asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" :
|
||||
"=r"(v.x), "=r"(v.y) : "r"(addr));
|
||||
|
||||
D = reinterpret_cast<AccessType const &>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
||||
@ -68,6 +68,18 @@ template <
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always>
|
||||
struct cp_async_zfill;
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicating
/// the entire transfer, NaNs (0x7eff) are written to SMEM if the guard predicate is false.
|
||||
///
|
||||
/// LDGSTS
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always>
|
||||
struct cp_async_nan;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
@ -150,6 +162,48 @@ struct cp_async_zfill<SizeInBytes, CacheOperation::Always> {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ __constant__ uint4 OOB_NAN_F16x8 = {0x7eff7eff, 0x7eff7eff,
|
||||
0x7eff7eff, 0x7eff7eff};
|
||||
|
||||
/// Partial specialization
|
||||
template <>
|
||||
struct cp_async_nan<16, CacheOperation::Always> {
|
||||
static int const kSizeInBytes = 16;
|
||||
|
||||
/// Copy with nan fill
|
||||
CUTLASS_DEVICE
|
||||
cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
|
||||
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
#if CUTLASS_ENABLE_L2_PREFETCH
|
||||
" @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
|
||||
#else
|
||||
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
|
||||
#endif
|
||||
" @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n"
|
||||
"}\n"
|
||||
:
|
||||
: "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr),
|
||||
"n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z),
|
||||
"r"(OOB_NAN_F16x8.w));
|
||||
|
||||
#else
|
||||
|
||||
CUTLASS_UNUSED(smem_ptr);
|
||||
CUTLASS_UNUSED(global_ptr);
|
||||
CUTLASS_UNUSED(pred_guard);
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
|
||||
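A sketch of how the NaN-filling copy might be issued from a mainloop-style kernel. The kernel, buffer sizes, and launch shape are hypothetical; `cp_async_nan` comes from the hunk above, and `cp_async_fence` / `cp_async_wait` are assumed to be the existing CUTLASS commit/wait helpers in the same header.

```cpp
#include "cutlass/numeric_types.h"
#include "cutlass/arch/memory_sm80.h"

// 128 threads each copy one 16-byte (8 x half_t) chunk. Threads whose chunk
// falls past the end of the buffer write 0x7eff NaNs into shared memory
// instead of stale or zero data.
__global__ void cp_async_nan_demo(cutlass::half_t const *gmem, int num_elements) {
  __shared__ alignas(16) cutlass::half_t smem[128 * 8];

  int offset = (blockIdx.x * 128 + threadIdx.x) * 8;
  bool guard = (offset + 8) <= num_elements;

  cutlass::arch::cp_async_nan<16, cutlass::arch::CacheOperation::Always>(
      &smem[threadIdx.x * 8], gmem + offset, guard);

  cutlass::arch::cp_async_fence();   // commit the LDGSTS group
  cutlass::arch::cp_async_wait<0>(); // wait for it to complete
  __syncthreads();

  // ... consume smem; out-of-bounds lanes observe NaN rather than garbage ...
}
```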
@ -62,6 +62,16 @@ struct OpMultiplyAddFastF16;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input is converted to 2 (big and small) TF32 components
//  Perform 3xTF32 or 4xTF32 for every F32 output element
struct OpMultiplyAddFastF32;

/// Tag indicating the input is converted to 2 (big and small) TF32 components
//  Perform 3xTF32 or 4xTF32 for every complex<F32> output element
struct OpMultiplyAddComplexFastF32;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the complex multiply-add operation
struct OpMultiplyAddComplex;
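These two tags select the TF32x3 emulation highlighted in the 2.8.0 notes. A standalone, host-only sketch of the underlying big/small split follows; it is illustrative only (plain truncation stands in for the hardware's TF32 rounding, and the kernels do this decomposition inside the mainloop rather than element by element).

```cpp
#include <cstdio>
#include <cstring>

// Keep roughly the 10 most significant mantissa bits, as TF32 does.
float to_tf32(float x) {
  unsigned u;
  std::memcpy(&u, &x, sizeof(u));
  u &= 0xffffe000u;                  // drop the 13 low mantissa bits
  float y;
  std::memcpy(&y, &u, sizeof(y));
  return y;
}

int main() {
  float a = 1.2345678f, b = 7.6543210f;

  float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  float b_big = to_tf32(b), b_small = to_tf32(b - b_big);

  // 3xTF32: three Tensor Core issues per product; the small*small term is dropped.
  float emulated = a_big * b_big + a_big * b_small + a_small * b_big;

  std::printf("fp32 %.9f  3xTF32 %.9f  1xTF32 %.9f\n", a * b, emulated, a_big * b_big);
  return 0;
}
```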
@ -1170,7 +1170,6 @@ struct Mma<
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
  using WmmaFragmentA = nvcuda::wmma::fragment<
    nvcuda::wmma::matrix_a,

@ -34,6 +34,7 @@
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
@ -2109,7 +2110,6 @@ struct Mma<

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, "

@ -2119,8 +2119,10 @@ struct Mma<
      "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else

  assert(0);
#endif

#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
  }
};
@ -79,7 +79,7 @@ struct Wmma<
    platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
    "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
    "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");

  // check supported wmma output data type for the given multiplicand data types
  static_assert(

@ -76,7 +76,7 @@ struct Wmma<
    platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
    "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
    "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");

  // Wmma Fragment

@ -157,7 +157,7 @@ struct Wmma<
    platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
    platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
    "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
    "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");

  // Wmma Fragment
  using FragmentA = nvcuda::wmma::fragment<
@ -338,6 +338,7 @@ private:

public:

#if 0
  CUTLASS_HOST_DEVICE
  Array() { }

@ -348,6 +349,7 @@ public:
      storage[i] = x.storage[i];
    }
  }
#endif

  /// Efficient clear method
  CUTLASS_HOST_DEVICE

@ -395,6 +395,7 @@ private:

public:

#if 0
  CUTLASS_HOST_DEVICE
  Array() { }

@ -405,6 +406,7 @@ public:
      storage[i] = x.storage[i];
    }
  }
#endif

  /// Efficient clear method
  CUTLASS_HOST_DEVICE
@ -47,6 +47,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/functional.h"

namespace cutlass {
namespace conv {

@ -485,6 +486,27 @@ int strided_dgrad_tile_m_per_filter(
  return tile_m_per_filter;
}

// Computes the starting Dx coord (h, w) for a given starting filter position
CUTLASS_HOST_DEVICE
void strided_dgrad_starting_coords(
  Conv2dProblemSize const &problem_size,
  FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
  int r, int s,
  int &start_h, int &start_w) {

  // function locals for remainder by fast divmod
  int pad_h_rem_, pad_w_rem_;

  // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
  stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
  int r_ = std::abs(problem_size.stride_h - (pad_h_rem_ - r));
  stride_h_divmod.divmod(start_h, r_);

  // start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
  stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
  int s_ = std::abs(problem_size.stride_w - (pad_w_rem_ - s));
  stride_w_divmod.divmod(start_w, s_);
}

} // namespace conv
} // namespace cutlass
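A quick host-side spot check of the formula preserved in the comments above. The helper below is hypothetical and only mirrors the commented expression, without `FastDivmod`:

```cpp
#include <cstdio>
#include <cstdlib>

int start_coord(int pad, int stride, int rs) {
  // Same arithmetic as strided_dgrad_starting_coords, in plain integer form.
  return std::abs(stride - ((pad % stride) - rs)) % stride;
}

int main() {
  // 3x3 filter, stride 2, padding 1: filter taps r = 0, 1, 2 start at Dx rows 1, 0, 1.
  for (int r = 0; r < 3; ++r) {
    std::printf("r = %d -> start_h = %d\n", r, start_coord(/*pad=*/1, /*stride=*/2, r));
  }
  return 0;
}
```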
@ -105,6 +105,18 @@ public:
      return status;
    }

    static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
    if (kConvolutionalOperator == conv::Operator::kFprop) {
      if (args.problem_size.K % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    } else if (kConvolutionalOperator == conv::Operator::kDgrad) {
      if (args.problem_size.C % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    } else if (kConvolutionalOperator == conv::Operator::kWgrad) {
      if (args.problem_size.C % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    }

    // check for unsupported problem sizes for strided dgrad implementation
    if (kConvolutionalOperator == conv::Operator::kDgrad &&
      kStrideSupport == conv::StrideSupport::kStrided) {

@ -217,14 +229,6 @@ public:
      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }

      result = cudaFuncSetAttribute(
        cutlass::Kernel<ImplicitGemmKernel>,
        cudaFuncAttributePreferredSharedMemoryCarveout, 100);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
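The new check in `can_implement()` dispatches on the convolution operator to decide which tensor extent must be divisible by the epilogue's vector width. A compact restatement, with the helper name made up for illustration:

```cpp
#include "cutlass/conv/convolution.h"

// Fprop writes an NPQK output, so K must be a multiple of the epilogue access width;
// Dgrad (NHWC) and Wgrad (KRSC) both have C as the contiguous output dimension.
inline bool output_extent_aligned(cutlass::conv::Operator op, int K, int C, int alignment_c) {
  int extent = (op == cutlass::conv::Operator::kFprop) ? K : C;
  return (extent % alignment_c) == 0;
}
```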
262
include/cutlass/conv/device/implicit_gemm_convolution_fusion.h
Normal file
@ -0,0 +1,262 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename ImplicitGemmFusionKernel_>
|
||||
class ImplicitGemmConvolutionFusion {
|
||||
public:
|
||||
|
||||
using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_;
|
||||
|
||||
using ElementA = typename ImplicitGemmFusionKernel::ElementA;
|
||||
using LayoutA = typename ImplicitGemmFusionKernel::LayoutA;
|
||||
using ElementB = typename ImplicitGemmFusionKernel::ElementB;
|
||||
using LayoutB = typename ImplicitGemmFusionKernel::LayoutB;
|
||||
|
||||
// using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias;
|
||||
// using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias;
|
||||
|
||||
using ElementC = typename ImplicitGemmFusionKernel::ElementC;
|
||||
using LayoutC = typename ImplicitGemmFusionKernel::LayoutC;
|
||||
using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator;
|
||||
using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute;
|
||||
using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass;
|
||||
using ArchTag = typename ImplicitGemmFusionKernel::ArchTag;
|
||||
using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape;
|
||||
using WarpShape = typename ImplicitGemmFusionKernel::WarpShape;
|
||||
using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape;
|
||||
using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle;
|
||||
using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp;
|
||||
static int const kStages = ImplicitGemmFusionKernel::kStages;
|
||||
static int const kConvDim = ImplicitGemmFusionKernel::kConvDim;
|
||||
using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator;
|
||||
using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator;
|
||||
using MathOperator = typename ImplicitGemmFusionKernel::MathOperator;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator;
|
||||
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm;
|
||||
|
||||
static int const kWarpCount =
|
||||
(ThreadblockShape::kM / WarpShape::kM) *
|
||||
(ThreadblockShape::kN / WarpShape::kN) *
|
||||
(ThreadblockShape::kK / WarpShape::kK);
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename ImplicitGemmFusionKernel::Arguments;
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename ImplicitGemmFusionKernel::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs Implicit GEMM
|
||||
ImplicitGemmConvolutionFusion() { }
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
// dispatch to iterators
|
||||
Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(
|
||||
threadblock_swizzle.get_tiled_shape(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices));
|
||||
|
||||
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
||||
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
||||
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices);
|
||||
|
||||
if(args.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
|
||||
// The user needs to call a reduction operator to obtain the final output tensor
|
||||
workspace_bytes =
|
||||
sizeof(ElementAccumulator) *
|
||||
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
|
||||
size_t(grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {
|
||||
|
||||
// Split-K serial: The user workspace is used to store semaphore and serialize writing the
|
||||
// final reduced output to user's output tensor
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
if (args.problem_size.split_k_slices > 1) {
|
||||
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
|
||||
|
||||
if (status != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
// initialize the params structure from the arguments
|
||||
params_ = typename ImplicitGemmFusionKernel::Params(
|
||||
args,
|
||||
static_cast<int *>(workspace)
|
||||
);
|
||||
|
||||
int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<ImplicitGemmFusionKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Initializes Implicit GEMM state from arguments.
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
// update the params structure from the arguments
|
||||
params_.ptr_A = args.ref_A.data();
|
||||
params_.ptr_B = args.ref_B.data();
|
||||
params_.ptr_scale = args.ref_A_scale.data();
|
||||
params_.ptr_bias = args.ref_A_bias.data();
|
||||
params_.ptr_C = args.ref_C.data();
|
||||
params_.ptr_D = args.ref_D.data();
|
||||
params_.output_op = args.output_op;
|
||||
params_.semaphore = static_cast<int *>(workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(32 * kWarpCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage));
|
||||
|
||||
cutlass::Kernel<ImplicitGemmFusionKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
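The new device-level wrapper exposes the same `can_implement` / `get_workspace_size` / `initialize` / `run` surface as the non-fused convolution. A minimal host-side driver sketch follows; `FusionKernel` stands for a kernel type composed elsewhere (e.g. by a `DefaultConv2d*Fusion` helper), and `args` is assumed to be an already-populated `Conv::Arguments` instance, since the Arguments layout is defined at the kernel level and not shown in this file.

```cpp
#include "cutlass/cutlass.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"

using Conv = cutlass::conv::device::ImplicitGemmConvolutionFusion<FusionKernel>;

cutlass::Status run_fused_conv(Conv::Arguments const &args,
                               void *workspace,
                               cudaStream_t stream = nullptr) {
  Conv conv_op;

  cutlass::Status status = Conv::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // workspace must hold Conv::get_workspace_size(args) bytes when split-K is used.
  status = conv_op.initialize(args, workspace, stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  return conv_op.run(stream);
}
```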
@ -41,10 +41,12 @@
|
||||
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
|
||||
#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
|
||||
#include "cutlass/conv/threadblock/implicit_gemm_multistage.h"
|
||||
#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h"
|
||||
#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h"
|
||||
#include "cutlass/conv/kernel/implicit_gemm_convolution.h"
|
||||
#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h"
|
||||
#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h"
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -65,8 +65,12 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultConv2dDgrad;
|
||||
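The `AlignmentA` / `AlignmentB` defaults added to this declaration (and to `DefaultConv2dFprop` later in the diff) resolve to one full 128-bit access worth of elements. A quick check of the arithmetic; only `sizeof_bits` from `cutlass/numeric_types.h` is assumed here:

```cpp
#include <cstdint>
#include "cutlass/numeric_types.h"

// Default access granularity = 128 bits' worth of elements.
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,  "8 halfs per access");
static_assert(128 / cutlass::sizeof_bits<float>::value == 4,            "4 floats per access");
static_assert(128 / cutlass::sizeof_bits<int8_t>::value == 16,          "16 int8s per access");
```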
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -90,7 +94,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -110,7 +116,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -121,24 +129,28 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -147,6 +159,11 @@ struct DefaultConv2dDgrad <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@ -155,7 +172,7 @@ struct DefaultConv2dDgrad <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@ -196,7 +213,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -216,7 +235,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -227,13 +248,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -241,13 +264,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -308,7 +333,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -328,7 +355,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -339,24 +368,28 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -365,6 +398,11 @@ struct DefaultConv2dDgrad <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@ -373,7 +411,7 @@ struct DefaultConv2dDgrad <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@ -414,7 +452,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -434,7 +474,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -445,13 +487,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -459,13 +503,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -526,7 +572,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -546,7 +594,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -557,23 +607,28 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -582,6 +637,11 @@ struct DefaultConv2dDgrad <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@ -590,7 +650,7 @@ struct DefaultConv2dDgrad <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@ -615,8 +675,8 @@ struct DefaultConv2dDgrad <
|
||||
>;
|
||||
};
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity
|
||||
// 2 stage pipeline
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and
|
||||
// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
@ -631,7 +691,129 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided
|
||||
// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -651,7 +833,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -662,13 +846,15 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -676,12 +862,132 @@ struct DefaultConv2dDgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
StrideSupport::kStrided,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad<
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
EpilogueOutputOp
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity
|
||||
// 2 stage pipeline
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
StrideSupport::kUnity,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -744,7 +1050,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag>
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@ -763,7 +1072,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -848,7 +1159,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag>
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
@ -867,7 +1181,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -955,7 +1271,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -975,7 +1293,9 @@ struct DefaultConv2dDgrad <
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1040,11 +1360,115 @@ struct DefaultConv2dDgrad <
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassSimt,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
conv::StrideSupport::kStrided
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
conv::StrideSupport::kStrided
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Always,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
||||
ThreadblockShape,
|
||||
WarpMmaSimtOp,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm,
|
||||
@ -1063,7 +1487,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -1083,7 +1509,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1169,7 +1597,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -1189,7 +1619,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1257,7 +1689,6 @@ struct DefaultConv2dDgrad <
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1278,7 +1709,9 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
@ -1298,7 +1731,9 @@ struct DefaultConv2dDgrad <
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport::kUnity
|
||||
StrideSupport::kUnity,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1368,10 +1803,119 @@ struct DefaultConv2dDgrad <
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dDgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassSimt,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport::kStrided,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA,
|
||||
conv::StrideSupport::kStrided
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIteratorStridedDgrad<
|
||||
cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB,
|
||||
conv::StrideSupport::kStrided
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
|
||||
ThreadblockShape,
|
||||
WarpMmaSimtOp,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kDgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -65,8 +65,12 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultConv2dFprop;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -90,7 +94,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -109,7 +116,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -120,22 +130,26 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -144,6 +158,11 @@ struct DefaultConv2dFprop <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@ -152,7 +171,7 @@ struct DefaultConv2dFprop <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@ -195,6 +214,9 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
@ -214,7 +236,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -312,7 +337,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -331,7 +359,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -342,12 +373,14 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -355,12 +388,14 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -419,6 +454,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
@ -438,7 +476,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -540,7 +581,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -559,7 +603,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -571,24 +618,28 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -597,6 +648,11 @@ struct DefaultConv2dFprop <
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
@ -605,7 +661,7 @@ struct DefaultConv2dFprop <
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
@ -648,6 +704,9 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
@ -667,7 +726,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -757,7 +819,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -776,7 +841,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -787,13 +855,15 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -801,13 +871,15 @@ struct DefaultConv2dFprop <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -866,6 +938,9 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
@ -885,7 +960,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -979,7 +1057,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -998,7 +1079,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1084,7 +1168,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1103,7 +1190,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1189,7 +1279,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1208,7 +1301,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -1295,7 +1391,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dFprop <
|
||||
ElementA,
|
||||
@ -1314,7 +1413,10 @@ struct DefaultConv2dFprop <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
|
||||
351  include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h  (new file)
@ -0,0 +1,351 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level definitions that fuse the activation operand's scale+bias+relu
with implicit GEMM convolution, combining threadblock-scoped matrix multiply-add
with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
|
||||
#include "cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h"
|
||||
#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Defines a kernel for fused batch norm and Conv2dFprop
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementScaleBias,
|
||||
typename LayoutScaleBias,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename OperatorClass,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv2dFpropFusion;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// OpClassTensorOp convolutions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementScaleBias,
|
||||
typename LayoutScaleBias,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv2dFpropFusion <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementScaleBias,
|
||||
LayoutScaleBias,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using IteratorScaleBias =
|
||||
cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
|
||||
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
||||
LayoutScaleBias>;
|
||||
|
||||
using SmemIteratorScaleBias =
|
||||
cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
|
||||
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
||||
LayoutScaleBias>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias<
|
||||
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
|
||||
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
|
||||
MmaCore::WarpCount::kK>;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorScaleBias,
|
||||
SmemIteratorScaleBias,
|
||||
arch::CacheOperation::Always,
|
||||
MmaPolicy,
|
||||
WarpIteratorScaleBias,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
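In the specialization above, the fused scale and bias are consumed as 1 x ThreadblockShape::kK vectors, i.e. one value per input channel along the GEMM K dimension. A hedged host-side sketch of how such vectors might be allocated (names and the channel count are illustrative; `cutlass::HostTensor` comes from the utility headers, not from this diff):

```cpp
// Illustrative only: allocate 1 x C scale/bias vectors for the fused fprop kernel.
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"

int main() {
  int const kChannels = 128;  // example input-channel count (C)

  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor>
      tensor_scale({1, kChannels});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor>
      tensor_bias({1, kChannels});

  // tensor_scale.device_ref() / tensor_bias.device_ref() are the kind of
  // TensorRefs the fused kernel's Arguments (ref_scale / ref_bias, defined in
  // implicit_gemm_convolution_fusion.h later in this diff) expect.
  return 0;
}
```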
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementScaleBias,
|
||||
typename LayoutScaleBias,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv2dFpropFusion <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementScaleBias,
|
||||
LayoutScaleBias,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using IteratorScaleBias =
|
||||
cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
|
||||
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
||||
LayoutScaleBias>;
|
||||
|
||||
using SmemIteratorScaleBias =
|
||||
cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
|
||||
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
|
||||
LayoutScaleBias>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias<
|
||||
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
|
||||
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
|
||||
MmaCore::WarpCount::kK>;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorScaleBias,
|
||||
SmemIteratorScaleBias,
|
||||
arch::CacheOperation::Always,
|
||||
MmaPolicy,
|
||||
WarpIteratorScaleBias,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
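Taken together, the specializations above let the new `DefaultConv2dFpropFusion` trait be instantiated much like `DefaultConv2dFprop`. A hedged sketch of such an instantiation (element types, tile shapes, and the include list are illustrative, loosely following the mainloop-fusion SDK example rather than anything stated in this diff):

```cpp
// Hedged sketch: illustrative configuration, not taken from this diff.
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h"  // added by this change

using ElementA           = cutlass::half_t;  // activations
using ElementB           = cutlass::half_t;  // filters
using ElementScaleBias   = cutlass::half_t;  // per-channel scale / bias
using ElementC           = cutlass::half_t;  // output
using ElementAccumulator = float;

using Conv2dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv2dFpropFusion<
    ElementA, cutlass::layout::TensorNHWC,
    ElementB, cutlass::layout::TensorNHWC,
    ElementScaleBias, cutlass::layout::RowMajor,
    ElementC, cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,     // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,      // Tensor Core instruction
    cutlass::epilogue::thread::LinearCombination<
        ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
        ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                        // multistage pipeline depth
    cutlass::arch::OpMultiplyAdd              // kOptimized iterator, strided support by default
>::Kernel;
```

The resulting `Kernel` type is what the device-level fusion adaptor introduced in this change set wraps.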
@ -64,8 +64,12 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
>
|
||||
struct DefaultConv2dFpropWithBroadcast {
|
||||
|
||||
@ -84,7 +88,9 @@ struct DefaultConv2dFpropWithBroadcast {
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm,
|
||||
StrideSupport
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
>::Kernel;
|
||||
|
||||
// Replace epilogue
|
||||
|
||||
@ -65,8 +65,12 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
>
|
||||
struct DefaultConv2dFpropWithReduction {
|
||||
|
||||
@ -85,7 +89,9 @@ struct DefaultConv2dFpropWithReduction {
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm,
|
||||
StrideSupport
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
>::Kernel;
|
||||
|
||||
// Replace epilogue
|
||||
|
||||
@ -66,9 +66,14 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultConv2dWgrad;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -93,7 +98,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -112,7 +120,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -123,22 +134,26 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -179,6 +194,7 @@ struct DefaultConv2dWgrad <
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two
|
||||
@ -198,7 +214,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -217,7 +236,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -228,12 +250,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -241,12 +265,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -308,7 +334,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -327,7 +356,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -338,22 +370,26 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
@ -394,6 +430,7 @@ struct DefaultConv2dWgrad <
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two
|
||||
@ -413,7 +450,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -432,7 +472,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -443,12 +486,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
@ -456,12 +501,14 @@ struct DefaultConv2dWgrad <
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
@ -524,7 +571,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -543,7 +593,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -629,7 +682,10 @@ template <
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -648,7 +704,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -732,7 +791,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -751,7 +813,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -817,7 +882,6 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -838,7 +902,10 @@ template <
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AccessTypeA,
|
||||
int AccessTypeB
|
||||
>
|
||||
struct DefaultConv2dWgrad <
|
||||
ElementA,
|
||||
@ -857,7 +924,10 @@ struct DefaultConv2dWgrad <
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AccessTypeA,
|
||||
AccessTypeB
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
@ -925,12 +995,11 @@ struct DefaultConv2dWgrad <
|
||||
>;
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
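The recurring change in this file, as in the fprop header above, is that every `DefaultConv2dWgrad` specialization now takes explicit `AlignmentA` / `AlignmentB` parameters, surfaced as the `AccessTypeA` / `AccessTypeB` aligned-array types fed to the tile iterators. The defaults keep the previous 128-bit behaviour; a quick sketch of what they evaluate to, and why a narrower value may be passed:

```cpp
// Sketch: the AlignmentA/AlignmentB defaults are "elements per 128-bit access".
#include <cstdint>
#include "cutlass/numeric_types.h"

static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
              "fp16 operands default to 8-element (128-bit) accesses");
static_assert(128 / cutlass::sizeof_bits<std::int8_t>::value == 16,
              "int8 operands default to 16-element (128-bit) accesses");
static_assert(128 / cutlass::sizeof_bits<float>::value == 4,
              "fp32/tf32 operands default to 4-element (128-bit) accesses");

// For tensors whose channel count breaks 128-bit alignment (e.g. an fp16 layer
// with C = 4), a smaller AlignmentA/AlignmentB can now be passed explicitly
// instead of relying on the 128-bit default.
```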
319  include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h  (new file)
@ -0,0 +1,319 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
|
||||
#include "cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementScaleBias,
|
||||
typename LayoutScaleBias,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename OperatorClass,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv2dWgradFusion;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// OpClassTensorOp convolutions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementScaleBias,
|
||||
typename LayoutScaleBias,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename OperatorClass,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv2dWgradFusion <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementScaleBias,
|
||||
LayoutScaleBias,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using IteratorScaleBias =
|
||||
cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator<
|
||||
cutlass::MatrixShape<1, WarpShape::kN>,
|
||||
ElementScaleBias,
|
||||
LayoutScaleBias>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmWgradFusionMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorScaleBias,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage
/// pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementScaleBias,
|
||||
typename LayoutScaleBias,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename OperatorClass,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv2dWgradFusion <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementScaleBias,
|
||||
LayoutScaleBias,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using IteratorScaleBias =
|
||||
cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator<
|
||||
cutlass::MatrixShape<1, WarpShape::kN>,
|
||||
ElementScaleBias,
|
||||
LayoutScaleBias>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmWgradFusionMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorScaleBias,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kWgrad
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
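Compared with the fprop fusion header earlier in this diff, the wgrad fusion specializations above load the scale/bias fragment directly with `PredicatedScaleBiasVectorIterator` over a 1 x WarpShape::kN shape and hand it to `ImplicitGemmWgradFusionMultistage`, without the shared-memory and warp-level scale/bias iterators used for fprop. A hedged instantiation sketch (illustrative types and tile shapes, assuming the same includes as the fprop sketch plus this header):

```cpp
// Hedged sketch: illustrative configuration for the new DefaultConv2dWgradFusion trait.
#include "cutlass/conv/kernel/default_conv2d_wgrad_fusion.h"  // added by this change

using WgradElementAccumulator = float;

using Conv2dWgradFusionKernel = typename cutlass::conv::kernel::DefaultConv2dWgradFusion<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // A: output gradient
    cutlass::half_t, cutlass::layout::TensorNHWC,   // B: activations
    cutlass::half_t, cutlass::layout::RowMajor,     // scale / bias vectors
    cutlass::half_t, cutlass::layout::TensorNHWC,   // C/D: weight gradient
    WgradElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,            // Tensor Core instruction
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
        WgradElementAccumulator, WgradElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                              // multistage pipeline depth
    cutlass::arch::OpMultiplyAdd                    // kOptimized iterator by default
>::Kernel;
```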
@ -66,7 +66,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv3dDgrad;
|
||||
|
||||
@ -228,7 +228,6 @@ struct DefaultConv3dDgrad <
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
|
||||
@ -66,7 +66,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv3dFprop;
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided
|
||||
> struct DefaultConv3dWgrad;
|
||||
|
||||
@ -501,4 +501,3 @@ struct DefaultConv3dWgrad <
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -385,7 +385,6 @@ struct ImplicitGemmConvolution {
|
||||
|
||||
semaphore.wait(threadblock_tile_idx.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
// Each split-k-slice writes to a unique tensor location
|
||||
else if (params.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
455  include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h  (new file)
@ -0,0 +1,455 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined Implicit GEMM kernel fused with the activation operand's scale+bias+relu.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
||||
typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem
|
||||
>
|
||||
struct ImplicitGemmConvolutionFusion {
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static Operator const kConvolutionalOperator = ConvOperator;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
|
||||
using ElementScaleBias = typename Mma::IteratorScaleBias::Element;
|
||||
using LayoutScaleBias = typename Mma::IteratorScaleBias::Layout;
|
||||
|
||||
using ElementC = typename EpilogueOutputOp::ElementOutput;
|
||||
using LayoutC = LayoutA;
|
||||
|
||||
using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
|
||||
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using WarpMmaOperator = typename Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename WarpMmaOperator::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||
using TensorRefScaleBias = typename Mma::IteratorScaleBias::TensorRef;
|
||||
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
|
||||
|
||||
/// Check that iterator A and iterator B have the same convolution dimension and
|
||||
// set device::ImplicitGemmConvolution::kConvDim
|
||||
static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
|
||||
"Convolution on different different dimensions is not supported");
|
||||
static int const kConvDim = Mma::IteratorA::kConvDim;
|
||||
|
||||
/// Conv dimension and problem size structure (Conv2d or Conv3d)
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
|
||||
/// Wgrad C stride idx for implicit gemm algorithm
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
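/// For Fprop/Dgrad, stride element 0 of the output layout is used; for Wgrad, the output is viewed as a
/// row-major KxRSC (Conv2d) or KxTRSC (Conv3d) matrix, so the stride index selected above is used instead.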
|
||||
static int const kTensorCStrideIdx =
|
||||
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
||||
|
||||
//
|
||||
//
|
||||
//
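// ConvOutputIteratorParameter adapts the convolution output tensor references (ref_C / ref_D)
// to the layout and extent expected by the epilogue's output tile iterator.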
|
||||
using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
|
||||
LayoutC,
|
||||
typename Epilogue::OutputTileIterator::Layout,
|
||||
TensorRefC,
|
||||
ConvOperator,
|
||||
ConvProblemSize
|
||||
>;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ConvProblemSize problem_size;
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
TensorRefScaleBias ref_scale;
|
||||
TensorRefScaleBias ref_bias;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size
|
||||
):
|
||||
problem_size(problem_size) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size,
|
||||
TensorRefA const & ref_A,
|
||||
TensorRefB const & ref_B,
|
||||
TensorRefScaleBias const & ref_scale,
|
||||
TensorRefScaleBias const & ref_bias,
|
||||
TensorRefC const & ref_C,
|
||||
TensorRefC const & ref_D,
|
||||
typename EpilogueOutputOp::Params const & output_op,
|
||||
SplitKMode const & split_k_mode = SplitKMode::kSerial
|
||||
):
|
||||
problem_size(problem_size),
|
||||
ref_A(ref_A),
|
||||
ref_B(ref_B),
|
||||
ref_scale(ref_scale),
|
||||
ref_bias(ref_bias),
|
||||
ref_C(ref_C),
|
||||
ref_D(ref_D),
|
||||
output_op(output_op),
|
||||
split_k_mode(split_k_mode)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
ConvProblemSize problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
gemm::GemmCoord implicit_gemm_problem_size;
|
||||
int swizzle_log_tile;
|
||||
int gemm_k_iterations;
|
||||
typename Mma::IteratorA::Params iterator_A;
|
||||
typename Mma::IteratorA::Element const *ptr_A;
|
||||
typename Mma::IteratorB::Params iterator_B;
|
||||
typename Mma::IteratorB::Element const *ptr_B;
|
||||
typename Mma::IteratorScaleBias::Params iterator_scale_bias;
|
||||
typename Mma::IteratorScaleBias::Element const *ptr_scale;
|
||||
typename Mma::IteratorScaleBias::Element const *ptr_bias;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_C;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_D;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
int *semaphore;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), gemm_k_iterations(0) { }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
int *semaphore = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
|
||||
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
|
||||
ptr_A(args.ref_A.data()),
|
||||
iterator_B(args.problem_size, args.ref_B.layout()),
|
||||
ptr_B(args.ref_B.data()),
|
||||
iterator_scale_bias(args.problem_size, args.ref_scale.layout()),
|
||||
ptr_scale(args.ref_scale.data()),
|
||||
ptr_bias(args.ref_bias.data()),
|
||||
iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)),
|
||||
ptr_C(args.ref_C.data()),
|
||||
iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)),
|
||||
ptr_D(args.ref_D.data()),
|
||||
output_op(args.output_op),
|
||||
semaphore(semaphore),
|
||||
split_k_mode(args.split_k_mode)
|
||||
{
|
||||
gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size);
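// (Number of mainloop steps of ThreadblockShape::kK along the implicit GEMM K extent.)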
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
implicit_gemm_problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices);
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
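/// (A union is sufficient because the mainloop and epilogue never use shared memory at the same time.)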
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ImplicitGemmConvolutionFusion() { }
|
||||
|
||||
/// Executes one ImplicitGEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A operand
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.iterator_A,
|
||||
params.problem_size,
|
||||
params.ptr_A,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
threadblock_tile_idx.k() * Mma::Shape::kK
|
||||
)
|
||||
);
|
||||
|
||||
// Construct iterators to B operand
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.iterator_B,
|
||||
params.problem_size,
|
||||
params.ptr_B,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.k() * Mma::Shape::kK,
|
||||
threadblock_tile_idx.n() * Mma::Shape::kN
|
||||
)
|
||||
);
|
||||
|
||||
// Construct iterators to A scale/bias vector
|
||||
typename Mma::IteratorScaleBias iterator_scale_bias(
|
||||
params.iterator_scale_bias,
|
||||
params.problem_size,
|
||||
params.ptr_scale,
|
||||
params.ptr_bias,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
0, (kConvolutionalOperator == conv::Operator::kFprop) ?
|
||||
(threadblock_tile_idx.k() * Mma::Shape::kK) :
|
||||
// Wgrad
|
||||
(threadblock_tile_idx.n() * Mma::Shape::kN)
|
||||
)
|
||||
);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(params.gemm_k_iterations, accumulators, iterator_A,
|
||||
iterator_B, iterator_scale_bias, accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
// Construct the semaphore.
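// One semaphore slot per output tile; threadblocks computing different split-K slices of the same
// tile serialize their epilogues on it.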
|
||||
int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();
|
||||
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// Compute logical position within grid
|
||||
threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
threadblock_tile_idx.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params.iterator_D,
|
||||
params.ptr_D,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator reading from source accumulator tensor
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.iterator_C,
|
||||
params.ptr_C,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_idx.k()) {
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_idx.k());
|
||||
|
||||
}
|
||||
// Each split-k-slice writes to a unique tensor location
|
||||
else if (params.split_k_mode == SplitKMode::kParallel) {
|
||||
iterator_D.add_pointer_offset(threadblock_tile_idx.k() *
|
||||
cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size));
|
||||
}
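// (In parallel split-K mode, each slice writes its partial results to a distinct copy of the output
// tensor, typically combined afterwards by a separate reduction pass.)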
|
||||
|
||||
// Run efficient epilogue
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_idx.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -199,7 +199,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
struct Params {
|
||||
ConvProblemSize problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
FastDivmod filter_s_divmod;
|
||||
FastDivmod stride_h_divmod;
|
||||
FastDivmod stride_w_divmod;
|
||||
int gemm_k_iterations;
|
||||
typename Mma::IteratorA::Params iterator_A;
|
||||
typename Mma::IteratorA::Element const *ptr_A;
|
||||
@ -227,7 +228,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
int *semaphore = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
filter_s_divmod(args.problem_size.stride_w),
|
||||
stride_h_divmod(args.problem_size.stride_h),
|
||||
stride_w_divmod(args.problem_size.stride_w),
|
||||
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
|
||||
ptr_A(args.ref_A.data()),
|
||||
iterator_B(args.problem_size, args.ref_B.layout()),
|
||||
@ -297,7 +299,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
// int start_s = filter_tile_m % (params.problem_size.stride_w);
|
||||
|
||||
int start_r, start_s;
|
||||
params.filter_s_divmod(start_r, start_s, filter_tile_m);
|
||||
params.stride_w_divmod(start_r, start_s, filter_tile_m);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
@ -320,6 +322,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.problem_size,
|
||||
params.ptr_A,
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
@ -386,6 +389,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.ptr_D,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset
|
||||
);
|
||||
@ -396,6 +400,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.ptr_C,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset
|
||||
);
|
||||
@ -418,7 +423,6 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
|
||||
semaphore.wait(threadblock_tile_idx.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
// Each split-k-slice writes to a unique tensor location
|
||||
else if (params.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
@ -439,7 +439,6 @@ struct ImplicitGemmConvolutionWithFusedEpilogue {
|
||||
|
||||
semaphore.wait(threadblock_tile_idx.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
// Each split-k-slice writes to a unique tensor location
|
||||
else if (params.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
@ -59,7 +59,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic;
|
||||
|
||||
@ -70,13 +71,15 @@ class Conv2dDgradFilterTileAccessIteratorAnalytic;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@ -88,7 +91,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -97,7 +100,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
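// Number of AccessType-sized vector accesses required to cover one ThreadMap element access.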
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
@ -113,6 +121,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
|
||||
@ -160,8 +169,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
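// (The linear index decomposes as: index = vector + kAccessesPerVector * (contiguous + Iterations::kContiguous * strided).)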
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -199,9 +210,9 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int c = offset_c_[iteration_contiguous_];
|
||||
int k = offset_k_[iteration_strided_];
|
||||
|
||||
int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
|
||||
@ -228,11 +239,18 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
@ -247,7 +265,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@ -261,13 +279,15 @@ public:
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
>{
|
||||
public:
|
||||
|
||||
@ -279,7 +299,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -288,7 +308,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
@ -304,6 +329,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension
|
||||
@ -346,8 +372,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -381,8 +409,8 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int c = offset_c_[iteration_contiguous_];
|
||||
int k = offset_k_[iteration_strided_];
|
||||
int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
@ -404,12 +432,17 @@ public:
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -429,7 +462,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@ -444,5 +477,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized;
|
||||
|
||||
@ -71,13 +72,15 @@ class Conv2dDgradFilterTileAccessIteratorOptimized;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@ -89,7 +92,283 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams {
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base):
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams(base) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
):
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams(
|
||||
problem_size,
|
||||
layout,
|
||||
sizeof_bits<Element>::value,
|
||||
{Shape::kRow, Shape::kColumn},
|
||||
ThreadMap::kThreads,
|
||||
ThreadMap::kElementsPerAccess,
|
||||
{ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
|
||||
{ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
|
||||
) { }
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_k_;
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
|
||||
int start_r_;
|
||||
int start_s_;
|
||||
|
||||
int64_t reset_bytes_s_;
|
||||
int64_t reset_bytes_r_;
|
||||
|
||||
//
|
||||
// Assertions
|
||||
//
|
||||
|
||||
// We map predicates into bits packed in this uint32_t container
|
||||
static_assert(ThreadMap::Iterations::kStrided *
|
||||
ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8,
|
||||
"Currently, the number of loads per iteration is limited by the size of the predicates container.");
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized(
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_{0},
|
||||
filter_r_(start_r),
|
||||
filter_s_(start_s),
|
||||
start_r_(start_r),
|
||||
start_s_(start_s) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.row() + thread_coord.strided();
|
||||
Index column = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0];
|
||||
reset_bytes_r_ = reset_bytes_s_ +
|
||||
(problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
|
||||
int filter_c = column + c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
|
||||
uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_[v] |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
}
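// predicates_[v] packs one validity bit per (contiguous, strided) access for vector access v.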
|
||||
|
||||
TensorCoord coord{filter_k_, filter_r_, filter_s_, column};
|
||||
|
||||
pointer_ += params_.layout(coord) * sizeof_bits<Element>::value / 8;
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
int next_idx = 0;
|
||||
LongIndex reset_bytes = params_.reset_bytes;
|
||||
|
||||
// Move filter_s by stride_w
|
||||
filter_s_ += problem_size_.stride_w;
|
||||
if (filter_s_ >= problem_size_.S) {
|
||||
|
||||
// Restore filter_s
|
||||
filter_s_ = start_s_;
|
||||
|
||||
// Move filter_r by stride_h
|
||||
filter_r_ += problem_size_.stride_h;
|
||||
|
||||
bool check = (filter_r_ < problem_size_.R);
|
||||
|
||||
filter_r_ = check ? filter_r_ : start_r_;
|
||||
next_idx = check ? 1 : 2;
|
||||
reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_);
|
||||
}
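// next_idx selects the precomputed byte increment: 0 = next filter s tap, 1 = wrap s and advance r,
// 2 = wrap both and advance to the next block of filter_k.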
|
||||
|
||||
// offset pointers by offset_bytes
|
||||
pointer_ += (params_.inc_next[next_idx] - reset_bytes);
|
||||
|
||||
if (next_idx == 2) {
|
||||
filter_k_ += params_.filter_k_delta;
|
||||
}
|
||||
|
||||
// Clear predicates if needed
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
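// One bit per contiguous access of strided iteration s; cleared below for all vector accesses when this K coordinate is out of range.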
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
predicates_[v] = (predicates_[v] & (~kClearMask));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_[iteration_vector_] & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_ +
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
|
||||
// Move to the next K coordinate within the tile
|
||||
pointer_ += params_.inc_next_strided;
|
||||
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The unity-stride specialization of Conv2dDgradFilterTileAccessIteratorOptimized is more performant for dgrad
|
||||
// on problem sizes with stride = {1x1}
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -98,7 +377,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
@ -139,9 +423,10 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_rs_;
|
||||
int filter_k_;
|
||||
|
||||
@ -167,7 +452,7 @@ public:
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
predicates_{0},
|
||||
filter_rs_(0),
|
||||
filter_k_(0) {
|
||||
|
||||
@ -184,11 +469,15 @@ public:
|
||||
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
|
||||
int filter_c = column + c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_ |= (pred << pred_idx);
|
||||
uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_[v] |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -202,8 +491,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -232,7 +523,11 @@ public:
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
predicates_ = (predicates_ & (~kClearMask));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
predicates_[v] = (predicates_[v] & (~kClearMask));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -243,19 +538,25 @@ public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_ & (1u << pred_idx));
|
||||
return (predicates_[iteration_vector_] & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_ +
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8);
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -280,7 +581,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@ -295,5 +596,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -59,7 +59,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -69,13 +70,15 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@ -86,7 +89,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -95,7 +98,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@ -118,6 +126,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
@ -130,7 +139,6 @@ private:
|
||||
int offset_p_[ThreadMap::Iterations::kStrided];
|
||||
int offset_q_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -139,6 +147,7 @@ public:
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
|
||||
):
|
||||
@ -164,9 +173,12 @@ public:
|
||||
}
|
||||
|
||||
// Starting h, w positions for filter position in gemm_k=0
|
||||
int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h);
|
||||
int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w);
|
||||
|
||||
int start_h, start_w;
|
||||
strided_dgrad_starting_coords(
|
||||
problem_size_,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
filter_r, filter_s,
|
||||
start_h, start_w);
|
||||
|
||||
// Effective P and Q for filter position required for remapping NHW rows
|
||||
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
|
||||
@ -206,8 +218,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -254,11 +268,13 @@ public:
|
||||
p += (conv_sign * (filter_r_ / problem_size_.stride_h));
|
||||
q += (conv_sign * (filter_s_ / problem_size_.stride_w));
|
||||
|
||||
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(
|
||||
n,
|
||||
p,
|
||||
q,
|
||||
filter_k_);
|
||||
k);
|
||||
}
|
||||
|
||||
|
||||
@ -288,11 +304,18 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
@ -307,14 +330,14 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by
|
||||
@ -322,13 +345,15 @@ public:
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
@ -339,7 +364,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -348,7 +373,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@ -388,6 +418,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
@ -439,8 +470,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -486,11 +519,12 @@ public:
|
||||
int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h;
|
||||
int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w;
|
||||
|
||||
return TensorCoord(n, p, q, filter_k_);
|
||||
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(n, p, q, k);
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Returns true if the current coordinate is within the output tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
@ -516,6 +550,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -541,7 +581,7 @@ public:
|
||||
}
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@ -549,7 +589,9 @@ public:
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
@ -61,11 +61,386 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling
|
||||
// to skip MMAs (Dx = Dy * w) on invalid filter positions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
using Mask = uint64_t;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams;
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
|
||||
// One pointer per access
|
||||
char const *pointer_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
int filter_k_;
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
int start_r_;
|
||||
int start_s_;
|
||||
int64_t reset_bytes_s_;
|
||||
int64_t reset_bytes_r_;
|
||||
|
||||
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
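// For each strided iteration and vector access, masks_[..][..][0] stores one valid bit per filter row r
// and masks_[..][..][1] stores one valid bit per filter column s.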
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorOptimized(
|
||||
Params const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
filter_k_(0),
|
||||
filter_r_(start_r),
|
||||
filter_s_(start_s),
|
||||
start_r_(start_r),
|
||||
start_s_(start_s) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0];
|
||||
|
||||
reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] +
|
||||
(problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1];
|
||||
|
||||
int offset_n[ThreadMap::Iterations::kStrided];
|
||||
int offset_p[ThreadMap::Iterations::kStrided];
|
||||
int offset_q[ThreadMap::Iterations::kStrided];
|
||||
|
||||
int filter_r = filter_r_;
|
||||
int filter_s = filter_s_;
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
filter_r = (problem_size_.R - 1 - filter_r);
|
||||
filter_s = (problem_size_.S - 1 - filter_s);
|
||||
}
|
||||
|
||||
// Starting h, w positions for filter position in gemm_k=0
|
||||
int start_h, start_w;
|
||||
strided_dgrad_starting_coords(
|
||||
problem_size_,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
filter_r, filter_s,
|
||||
start_h, start_w);
|
||||
|
||||
|
||||
// Effective starting P and Q for filter position required for remapping NHW rows
|
||||
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
|
||||
int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
pointer_[s] = reinterpret_cast<char const *>(ptr);
|
||||
|
||||
int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter;
|
||||
|
||||
// (STEP 1) [reorder NHW rows to start with same filter positions]
|
||||
offset_n[s] = offset_npq / (P * Q);
|
||||
int residual = offset_npq % (P * Q);
|
||||
|
||||
int p = (residual / Q);
|
||||
int q = (residual % Q);
|
||||
|
||||
int mapped_h = (start_h + p * problem_size_.stride_h);
|
||||
int mapped_w = (start_w + q * problem_size_.stride_w);
|
||||
|
||||
// Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0
|
||||
// note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be
|
||||
// divisible by stride_h and stride_w
|
||||
offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h;
|
||||
offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w;
|
||||
|
||||
// Initialize pointers for gemm_k=0
|
||||
TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_};
|
||||
|
||||
pointer_[s] += params_.layout(coord) * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
//
// Precompute mask predicates
//
clear_mask();

CUTLASS_PRAGMA_NO_UNROLL
for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) {
  CUTLASS_PRAGMA_UNROLL
  for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

    int p = offset_p[s_idx];

    p += (params_.conv_sign * (r / problem_size_.stride_h));

    bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);

    CUTLASS_PRAGMA_UNROLL
    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
      masks_[s_idx][v_idx][0] |= (pred << r);
    }
  }
}

CUTLASS_PRAGMA_NO_UNROLL
for (int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) {
  CUTLASS_PRAGMA_UNROLL
  for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

    int q = offset_q[s_idx];
    q += (params_.conv_sign * (s / problem_size_.stride_w));

    bool pred = (q >= 0 && q < problem_size_.Q);

    CUTLASS_PRAGMA_UNROLL
    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
      masks_[s_idx][v_idx][1] |= (pred << s);
    }
  }
}

CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
  clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K);
}

set_iteration_index(0);
}
CUTLASS_HOST_DEVICE
static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
  return Params(problem_size,
                layout,
                sizeof_bits<Element>::value,
                {Shape::kRow, Shape::kColumn});
}

private:

/// Adds a byte offset to each access pointer; an optional byte_reset moves the pointers back
CUTLASS_HOST_DEVICE
void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) {

  CUTLASS_PRAGMA_UNROLL
  for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
    pointer_[s] += byte_offset - byte_reset;
  }
}

public:

/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
  iteration_vector_ = index % kAccessesPerVector;
  int residual_access = index / kAccessesPerVector;
  iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
  iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}

/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
  add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_HOST_DEVICE
void advance() {

  int next_idx = 0;
  int64_t reset_bytes = 0;

  // Move filter_s by stride_w
  filter_s_ += problem_size_.stride_w;
  if (filter_s_ >= problem_size_.S) {

    // Restore filter_s
    filter_s_ = start_s_;

    // Move filter_r by stride_h
    filter_r_ += problem_size_.stride_h;
    if (filter_r_ < problem_size_.R) {

      next_idx = 1;

      // Restore bytes in q coordinate (Mma in filter s dimension)
      reset_bytes = reset_bytes_s_;

    } else {

      // Restore filter_r
      filter_r_ = start_r_;

      next_idx = 2;

      // Restore bytes in p and q coordinates (Mma in filter s and r dimensions)
      reset_bytes = reset_bytes_r_;
    }
  }

  // Offset pointers by inc_next[next_idx], moving back by reset_bytes
  add_byte_offset_(params_.inc_next[next_idx] - reset_bytes);

  if (next_idx == 2) {
    filter_k_ += params_.filter_k_delta;
  }

  CUTLASS_PRAGMA_UNROLL
  for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
    clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
  }
}
/// Clears the predicates
CUTLASS_HOST_DEVICE
void clear_mask(bool clear = true) {
  CUTLASS_PRAGMA_UNROLL
  for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
    CUTLASS_PRAGMA_UNROLL
    for (int v = 0; v < kAccessesPerVector; ++v) {
      masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
      masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
    }
  }
}

/// Clears the predicates for one access within the vector
CUTLASS_HOST_DEVICE
void clear_mask(int v, bool clear = true) {
  CUTLASS_PRAGMA_UNROLL
  for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
    masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
    masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
  }
}

/// Returns true if the current coordinate is within the output tensor Dy
CUTLASS_HOST_DEVICE
bool valid() const {
  return
    (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
    (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {

  return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
}

/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
  ++iteration_vector_;
  if (iteration_vector_ < kAccessesPerVector) {
    return *this;
  }
  iteration_vector_ = 0;

  ++iteration_contiguous_;
  if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
    return *this;
  }
  iteration_contiguous_ = 0;

  ++iteration_strided_;
  if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
    return *this;
  }
  iteration_strided_ = 0;

  return *this;
}

/// Determines whether the Implicit GEMM can execute the given problem.
CUTLASS_HOST_DEVICE
static Status can_implement(Conv2dProblemSize const &problem_size) {

  // check alignment constraint on iterator's contiguous dimension
  if (problem_size.K % AccessType::kElements) {
    return Status::kErrorInvalidProblem;
  }

  // Limit on filter size
  if (problem_size.R > 32 || problem_size.S > 32) {
    return Status::kErrorNotSupported;
  }

  return Status::kSuccess;
}
};
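For reference, the predicate scheme above stores one bit per filter row in `masks_[...][0]` and one bit per filter column in `masks_[...][1]`, so `valid()` reduces to two bit tests with no divisions in the mainloop. Below is a minimal host-side sketch of the same idea; the sizes and names are illustrative, not the CUTLASS types, and it assumes the cross-correlation direction in which the accessed row decreases as the filter row grows.

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the two-axis bitmask predicate used by the strided-dgrad Dy iterator.
int main() {
  constexpr int R = 3, S = 3;        // filter extent (the iterator requires R, S <= 32)
  constexpr int P = 4, Q = 4;        // output extent
  uint64_t mask_r = 0, mask_s = 0;   // one bit per filter row / column

  int p0 = 3, q0 = 1;                // access coordinates for filter position (0, 0)

  // Precompute: bit r is set iff row p0 - r stays inside [0, P); likewise for columns.
  for (int r = 0; r < R; ++r) mask_r |= uint64_t(p0 - r >= 0 && p0 - r < P) << r;
  for (int s = 0; s < S; ++s) mask_s |= uint64_t(q0 - s >= 0 && q0 - s < Q) << s;

  // valid() for a given (r, s) is then just two bit tests.
  for (int r = 0; r < R; ++r)
    for (int s = 0; s < S; ++s)
      printf("r=%d s=%d valid=%d\n", r, s,
             int((mask_r >> r & 1) && (mask_s >> s & 1)));
  return 0;
}
```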
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad
|
||||
// with problem stride = {1x1}
|
||||
@ -74,14 +449,16 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
> {
|
||||
conv::StrideSupport::kUnity,
|
||||
AccessType_
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
@ -93,7 +470,7 @@ public:
|
||||
using Layout = layout::TensorNHWC;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@ -101,7 +478,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
using Mask = uint64_t;
|
||||
|
||||
//
|
||||
@ -122,6 +504,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
|
||||
// One pointer per access
|
||||
char const *pointer_[ThreadMap::Iterations::kStrided];
|
||||
@ -131,7 +514,7 @@ private:
|
||||
int filter_s_;
|
||||
int filter_k_;
|
||||
|
||||
Index masks_[ThreadMap::Iterations::kStrided][2];
|
||||
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
|
||||
|
||||
public:
|
||||
|
||||
@ -199,7 +582,11 @@ public:
|
||||
int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h;
|
||||
|
||||
bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);
|
||||
masks_[s_idx][0] |= (pred << r);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][0] |= (pred << r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,12 +603,17 @@ public:
|
||||
int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w;
|
||||
|
||||
bool pred = (q >= 0 && q < problem_size_.Q);
|
||||
masks_[s_idx][1] |= (pred << s);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][1] |= (pred << s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (filter_k_ >= problem_size.K) {
|
||||
clear_mask();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_k_ >= problem_size.K);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
@ -267,62 +659,15 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
// We are using inline PTX assembly here to avoid a CUDA C++ compilation
|
||||
// artifact in which control flow instructions are generated. Instead, our
|
||||
// intent is to predicate the mov instructions.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][0])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][0])
|
||||
);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][1])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][1])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
masks_[s][0] = 0;
|
||||
masks_[s][1] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of element
|
||||
@ -357,16 +702,32 @@ public:
|
||||
filter_k_ += params_.filter_k_delta;
|
||||
}
|
||||
|
||||
clear_mask_(filter_k_ >= problem_size_.K);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][0] = Mask(0);
|
||||
masks_[s][1] = Mask(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask(int v, bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
|
||||
@ -374,20 +735,25 @@ public:
|
||||
bool valid() {
|
||||
|
||||
return
|
||||
(masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][1] & (Index(1) << filter_s_));
|
||||
(masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
@ -414,7 +780,7 @@ public:
|
||||
}
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
|
||||
@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropActivationTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@ -74,7 +75,7 @@ public:
|
||||
using Layout = Layout_;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@ -82,7 +83,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@ -101,6 +107,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_c_;
|
||||
@ -154,8 +161,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -200,7 +209,9 @@ public:
|
||||
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_);
|
||||
int c = filter_c_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(n, h, w, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activations tensor X
|
||||
@ -230,6 +241,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -250,7 +267,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
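The relaxed check above only requires the channel count to be divisible by `AccessType::kElements`; the thread map's 128-bit vector is then split into `kAccessesPerVector` narrower accesses. The sketch below shows how the access width, `kAccessesPerVector`, and the old and new alignment rules relate; the element width and channel count are made-up example values.

```cpp
#include <cstdio>

// Sketch: relaxed alignment check vs. kAccessesPerVector, assuming a 16-bit element type.
int main() {
  constexpr int kElementBits = 16;
  constexpr int kElementsPerAccess = 8;            // thread-map vector: 8 x 16b = 128b
  const int C = 34;                                // channel count, not 128b-aligned

  const int choices[] = {8, 4, 2, 1};              // candidate AccessType::kElements
  for (int access_elements : choices) {
    int accesses_per_vector = kElementsPerAccess / access_elements;
    bool ok_old = (C % (128 / kElementBits)) == 0; // old 128b-only rule
    bool ok_new = (C % access_elements) == 0;      // relaxed per-AccessType rule
    printf("access=%d elems (%3d bits): kAccessesPerVector=%d old_ok=%d new_ok=%d\n",
           access_elements, access_elements * kElementBits,
           accesses_per_vector, (int)ok_old, (int)ok_new);
  }
  return 0;
}
```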
|
||||
|
||||
|
||||
@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
@ -74,7 +75,7 @@ public:
|
||||
using Layout = Layout_;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
@ -85,6 +86,11 @@ public:
|
||||
|
||||
using Mask = uint64_t;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@ -103,6 +109,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
|
||||
// One pointer per access
|
||||
char const *pointer_[ThreadMap::Iterations::kStrided];
|
||||
@ -112,7 +119,7 @@ private:
|
||||
int filter_s_;
|
||||
int filter_c_;
|
||||
|
||||
Index masks_[ThreadMap::Iterations::kStrided][2];
|
||||
Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
|
||||
|
||||
public:
|
||||
|
||||
@ -180,7 +187,11 @@ public:
|
||||
int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h;
|
||||
|
||||
bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H);
|
||||
masks_[s_idx][0] |= (pred << r);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][0] |= (pred << r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,12 +208,17 @@ public:
|
||||
int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w;
|
||||
|
||||
bool pred = (w >= 0 && w < problem_size_.W);
|
||||
masks_[s_idx][1] |= (pred << s);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
masks_[s_idx][v_idx][1] |= (pred << s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size.C) {
|
||||
clear_mask();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
@ -247,63 +263,17 @@ private:
|
||||
pointer_[s] += byte_offset;
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask_(bool clear) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
// We are using inline PTX assembly here to avoid a CUDA C++ compilation
|
||||
// artifact in which control flow instructions are generated. Instead, our
|
||||
// intent is to predicate the mov instructions.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][0])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][0])
|
||||
);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" .reg .u32 m;"
|
||||
" mov.u32 m, %2;"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p mov.u32 m, 0;\n"
|
||||
" mov.u32 %0, m;\n"
|
||||
"}\n"
|
||||
:
|
||||
"=r"(masks_[s][1])
|
||||
:
|
||||
"r"((int)clear),
|
||||
"r"(masks_[s][1])
|
||||
);
|
||||
#else
|
||||
if (clear) {
|
||||
masks_[s][0] = 0;
|
||||
masks_[s][1] = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of element
|
||||
@ -338,16 +308,32 @@ public:
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
clear_mask_(filter_c_ >= problem_size_.C);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][0] = Mask(0);
|
||||
masks_[s][1] = Mask(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
masks_[s][v][0] = clear ? 0 : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? 0 : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask(int v, bool clear = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
masks_[s][v][0] = clear ? 0 : masks_[s][v][0];
|
||||
masks_[s][v][1] = clear ? 0 : masks_[s][v][1];
|
||||
}
|
||||
}
|
||||
|
||||
@ -355,21 +341,27 @@ public:
|
||||
bool valid() {
|
||||
|
||||
return
|
||||
(masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][1] & (Index(1) << filter_s_));
|
||||
(masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
|
||||
(masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
|
||||
return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationTileAccessIteratorOptimized &operator++() {
|
||||
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -390,7 +382,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
||||
@ -59,7 +59,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@ -72,7 +73,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -81,7 +82,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@ -100,6 +106,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_r_;
|
||||
@ -140,8 +147,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -174,8 +183,9 @@ public:
|
||||
TensorCoord at() const {
|
||||
|
||||
int k = offset_k_[iteration_strided_];
|
||||
int c = filter_c_ + iteration_vector_ * AccessType::kElements;
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, filter_c_);
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activations tensor W
|
||||
@ -201,6 +211,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -221,7 +237,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
@ -248,5 +264,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
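Every iterator touched by this change replaces the old two-level iteration index with a three-level one in which the per-vector access varies fastest. The standalone sketch below shows the decomposition performed by `set_iteration_index` and the flat order walked by `operator++`; the iteration counts are illustrative, not taken from a real ThreadMap.

```cpp
#include <cassert>

// index -> (vector, contiguous, strided), fastest-varying first.
int main() {
  constexpr int kAccessesPerVector = 2;
  constexpr int kContiguous = 4;
  constexpr int kStrided = 3;

  for (int index = 0; index < kAccessesPerVector * kContiguous * kStrided; ++index) {
    int vector     = index % kAccessesPerVector;
    int residual   = index / kAccessesPerVector;
    int contiguous = residual % kContiguous;
    int strided    = residual / kContiguous;

    // operator++ advances in the same order: vector, then contiguous, then strided.
    int reconstructed = vector + kAccessesPerVector * (contiguous + kContiguous * strided);
    assert(reconstructed == index);
  }
  return 0;
}
```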
|
||||
|
||||
|
||||
|
||||
@ -60,7 +60,8 @@ template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||
public:
|
||||
@ -73,7 +74,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -82,7 +83,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
@ -127,9 +133,10 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_rs_;
|
||||
int filter_c_;
|
||||
|
||||
@ -154,7 +161,7 @@ public:
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
predicates_{0},
|
||||
filter_rs_(0),
|
||||
filter_c_(0) {
|
||||
|
||||
@ -166,11 +173,16 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
|
||||
predicates_ |= (pred << s);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
predicates_[v_idx] |= (pred << s);
|
||||
}
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size.C) {
|
||||
predicates_ = 0u;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
@ -183,8 +195,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -206,29 +220,42 @@ public:
|
||||
next = params_.inc_next_c;
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size_.C) {
|
||||
predicates_ = 0;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
}
|
||||
|
||||
/// Clears the predicates
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask(int v, bool clear = true) {
|
||||
predicates_[v] = clear ? 0u : predicates_[v];
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
return (predicates_ & (1u << iteration_strided_));
|
||||
return (predicates_[iteration_vector_] & (1u << iteration_strided_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_);
|
||||
return reinterpret_cast<AccessType const *>(pointer_) + iteration_vector_;
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -253,7 +280,7 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.C % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
|
||||
@ -527,6 +527,64 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
// Strided Dgrad Optimized Dy params (layout::TensorNHWC)
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams {

  using Layout = layout::TensorNHWC;

  Layout layout;

  int64_t inc_next[3];        // {next S, next R, next K}

  int filter_k_delta;         // number of logical elements to add to filter_k_

  int tiled_rows_per_filter;

  int conv_sign;

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { }

  CUTLASS_HOST_DEVICE
  Conv2dStridedDgradOutputGradientIteratorOptimizedParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,          ///< layout object
    int element_size_bits,         ///< size of each element in bits
    MatrixCoord threadblock_shape
  ): layout(layout) {

    int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row());

    tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row();

    conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1);

    // next S
    inc_next[0] = conv_sign * (
      layout.stride()[0] * problem_size.dilation_w
    ) * element_size_bits / 8;

    // next R
    inc_next[1] = conv_sign * (
      layout.stride()[1] * problem_size.dilation_h
    ) * element_size_bits / 8;

    // next K
    inc_next[2] = (
      threadblock_shape.column() * problem_size.split_k_slices
    ) * element_size_bits / 8;

    // logical offset added to internal channel counter - units are elements, not bytes
    filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices;
  }
};
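This params object folds all layout arithmetic into byte-valued pointer increments once, so the iterator's `advance()` only adds precomputed deltas. Below is a host-side sketch of the same conversion for a TensorNHWC Dy tensor; the strides and sizes are made-up example values, the element is assumed 16-bit, and `conv_sign` is -1 for cross-correlation mode as in the struct above.

```cpp
#include <cstdint>
#include <cstdio>

// Sketch: byte increments for an NHWC output-gradient tensor, mirroring inc_next[].
int main() {
  // NHWC strides in elements for the (w, h, n) dimensions: {C, W*C, H*W*C}.
  const int64_t C = 64, W = 56;
  const int64_t stride_w = C, stride_h = W * C;   // layout.stride()[0], layout.stride()[1]
  const int element_bits = 16;                    // e.g. a half-precision element
  const int dilation_w = 1, dilation_h = 1;
  const int conv_sign = -1;                       // cross-correlation mode

  int64_t inc_next_s = conv_sign * stride_w * dilation_w * element_bits / 8;
  int64_t inc_next_r = conv_sign * stride_h * dilation_h * element_bits / 8;

  printf("next S: %lld bytes, next R: %lld bytes\n",
         (long long)inc_next_s, (long long)inc_next_r);
  return 0;
}
```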
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Dgrad Optimized w params (layout::TensorNHWC)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -584,6 +642,73 @@ struct Conv2dDgradFilterIteratorOptimizedParams {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// StridedDgrad Optimized w params (layout::TensorNHWC)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct Conv2dStridedDgradFilterIteratorOptimizedParams {
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
int RS;
|
||||
int filter_k_delta;
|
||||
|
||||
int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile
|
||||
int64_t inc_next[3]; // {next S, next R, next K}
|
||||
int64_t reset_bytes; // offset in units of bytes to move back the pointer
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dStridedDgradFilterIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout), RS(problem_size.R * problem_size.S) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
|
||||
|
||||
// next S
|
||||
inc_next[0] =
|
||||
( layout.stride()[0] * problem_size.stride_w
|
||||
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// next R
|
||||
inc_next[1] =
|
||||
( layout.stride()[1] * problem_size.stride_h
|
||||
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// next K
|
||||
inc_next[2] =
|
||||
(
|
||||
threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
|
||||
//- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
|
||||
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// offset in units of bytes to move the pointer in the backward direction
|
||||
reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
* element_size_bits / 8;
|
||||
|
||||
filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
|
||||
}
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator
|
||||
struct Conv2dWgradOutputGradientIteratorOptimizedParams {
|
||||
|
||||
|
||||
@ -68,6 +68,7 @@ public:
|
||||
using Params = typename TileAccessIterator::Params;
|
||||
static int const kConvDim = TileAccessIterator::kConvDim;
|
||||
using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;
|
||||
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
||||
|
||||
/// Fragment object to be loaded or stored
|
||||
using Fragment = cutlass::Array<
|
||||
@ -130,17 +131,22 @@ public:
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kAccessesPerVector; ++v) {
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
|
||||
tile_access_iterator_.get() + pointer_offset,
|
||||
tile_access_iterator_.valid()
|
||||
);
|
||||
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
||||
|
||||
++tile_access_iterator_;
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[idx],
|
||||
tile_access_iterator_.get() + pointer_offset,
|
||||
tile_access_iterator_.valid()
|
||||
);
|
||||
|
||||
++tile_access_iterator_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -200,7 +206,27 @@ private:
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
/// Constructor (output gradient (Dy) OperandA ctor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorStridedDgrad(
|
||||
Params const ¶ms,
|
||||
ConvProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(
|
||||
params,
|
||||
problem_size,
|
||||
ptr,
|
||||
thread_idx,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset) { }
|
||||
|
||||
/// Constructor (filter (w) OperandB ctor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorStridedDgrad(
|
||||
Params const ¶ms,
|
||||
@ -210,7 +236,12 @@ public:
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { }
|
||||
tile_access_iterator_(params,
|
||||
problem_size,
|
||||
ptr,
|
||||
thread_idx,
|
||||
start_r, start_s,
|
||||
threadblock_offset) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {
|
||||
|
||||
@ -58,7 +58,8 @@ namespace threadblock {
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dWgradActivationTileAccessIteratorAnalytic {
|
||||
public:
|
||||
@ -70,7 +71,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -79,7 +80,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@ -95,6 +101,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// Filter position (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k
|
||||
@ -147,8 +154,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -171,9 +180,22 @@ public:
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
int r, s, c;
|
||||
|
||||
int r = filter_r_[iteration_contiguous_];
|
||||
int s = filter_s_[iteration_contiguous_];
|
||||
if (kAccessesPerVector == 1) {
|
||||
/// One 128b aligned access fetching more than one element
|
||||
c = filter_c_[iteration_contiguous_];
|
||||
r = filter_r_[iteration_contiguous_];
|
||||
s = filter_s_[iteration_contiguous_];
|
||||
}
|
||||
else {
|
||||
/// Multiple accesses to support non-128b alignment in the contiguous dimension
|
||||
c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C;
|
||||
int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C;
|
||||
s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S;
|
||||
int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S;
|
||||
r = filter_r_[iteration_contiguous_] + wrap_s;
|
||||
}
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
r = (problem_size_.R - 1 - r);
|
||||
@ -182,14 +204,14 @@ public:
|
||||
|
||||
int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q);
|
||||
int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q);
|
||||
|
||||
|
||||
int p = residual / problem_size_.Q;
|
||||
int q = residual % problem_size_.Q;
|
||||
|
||||
|
||||
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
|
||||
|
||||
return TensorCoord(n, h, w, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
@ -199,8 +221,7 @@ public:
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@ -216,6 +237,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -235,13 +262,12 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
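When `kAccessesPerVector > 1`, the per-vector element offset in the analytic wgrad activation iterator can spill past the end of the C extent, so the overflow is carried into s and then r, as in the `at()` method above. The sketch below replays that carry chain for the cross-correlation branch only, with small made-up extents chosen to force wrapping.

```cpp
#include <cstdio>

// Sketch of the (c, s, r) carry used when a thread-map vector spans several AccessTypes.
int main() {
  const int C = 2, S = 3;                        // small channel count forces wrapping
  const int access_elements = 1;

  int filter_c = 1, filter_s = 2, filter_r = 0;  // starting logical coordinates
  for (int v = 0; v < 4; ++v) {
    int c      = (filter_c + v * access_elements) % C;
    int wrap_c = (filter_c + v * access_elements) / C;
    int s      = (filter_s + wrap_c) % S;
    int wrap_s = (filter_s + wrap_c) / S;
    int r      = filter_r + wrap_s;
    printf("v=%d -> (r=%d, s=%d, c=%d)\n", v, r, s, c);
  }
  return 0;
}
```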
|
||||
|
||||
@ -57,7 +57,8 @@ namespace threadblock {
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
|
||||
>
|
||||
class Conv2dWgradActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
@ -69,7 +70,7 @@ public:
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
@ -78,7 +79,12 @@ public:
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
@ -94,6 +100,7 @@ private:
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
// Precomputed effective filter position (r,s) in contiguous dimension stays constant for each gemm_iteration_k
|
||||
@ -151,9 +158,8 @@ public:
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
}
|
||||
|
||||
// initialize n, p, q offset for every strided iteration
|
||||
@ -168,8 +174,10 @@ public:
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
@ -192,6 +200,33 @@ public:
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
int r = precomputed_filter_r_[iteration_contiguous_];
|
||||
int s = precomputed_filter_s_[iteration_contiguous_];
|
||||
int c = filter_c_[iteration_contiguous_];
|
||||
|
||||
if (kAccessesPerVector > 1) {
|
||||
// This code section is only to support non-128b alignment
|
||||
// Multiple accesses to support non-128b alignment in the contiguous dimension
|
||||
int wrap_c;
|
||||
params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements);
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
s -= (problem_size_.dilation_w * wrap_c);
|
||||
|
||||
int wrap_s = (s == -problem_size_.pad_w - problem_size_.dilation_w);
|
||||
s = wrap_s ? (-problem_size_.pad_w + (problem_size_.S - 1) * problem_size_.dilation_w): s;
|
||||
|
||||
r -= (problem_size_.dilation_h * wrap_s);
|
||||
|
||||
} else {
|
||||
s += (problem_size_.dilation_w * wrap_c);
|
||||
|
||||
int wrap_s = (s == (-problem_size_.pad_w + problem_size_.S * problem_size_.dilation_w));
|
||||
s = wrap_s ? -problem_size_.pad_w : s;
|
||||
|
||||
r += (problem_size_.dilation_h * wrap_s);
|
||||
}
|
||||
}
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
@ -207,10 +242,10 @@ public:
|
||||
params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]);
|
||||
params_.q_divmod(p, q, residual);
|
||||
|
||||
int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];
|
||||
int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_];
|
||||
int h = p * problem_size_.stride_h + r;
|
||||
int w = q * problem_size_.stride_w + s;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
|
||||
return TensorCoord(n, h, w, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
@ -220,8 +255,7 @@ public:
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
@ -237,6 +271,12 @@ public:
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
@ -256,14 +296,14 @@ public:
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
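The optimized wgrad activation iterator replaces the divisions described in its comment block with precomputed `fast_divmod` objects; the arithmetic they implement is the ordinary decomposition sketched below. Problem sizes and the starting offsets are example values only.

```cpp
#include <cstdio>

// Plain-integer equivalent of the pq_divmod / q_divmod calls in at().
int main() {
  const int P = 28, Q = 28;
  const int stride_h = 2, stride_w = 2;
  const int r = -1, s = 0;                 // precomputed_filter_r/s already hold -pad + k*dilation

  int offset_npq = 3 * P * Q + 5 * Q + 7;  // some linearized (n, p, q) position
  int n = offset_npq / (P * Q);
  int residual = offset_npq % (P * Q);
  int p = residual / Q;
  int q = residual % Q;

  int h = p * stride_h + r;                // activation coordinates for this output position
  int w = q * stride_w + s;
  printf("n=%d p=%d q=%d -> h=%d w=%d\n", n, p, q, h, w);
  return 0;
}
```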
|
||||
|
||||
@@ -58,7 +58,8 @@ namespace threadblock {
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dWgradOutputGradientTileAccessIteratorAnalytic {
public:
@@ -70,7 +71,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@@ -80,6 +81,11 @@ public:
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;

static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");

static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");

@@ -95,6 +101,7 @@ private:
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;

int filter_k_[ThreadMap::Iterations::kContiguous];
@@ -141,8 +148,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}

/// Adds a pointer offset in units of Element
@@ -173,7 +182,9 @@ public:
int p = residual / problem_size_.Q;
int q = residual % problem_size_.Q;

return TensorCoord(n, p, q, filter_k_[iteration_contiguous_]);
int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements;

return TensorCoord(n, p, q, k);
}


@@ -201,6 +212,12 @@ public:
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;

++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@@ -220,14 +237,14 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {

// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}

return Status::kSuccess;
}

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
@@ -235,5 +252,3 @@ public:
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////



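// Illustrative sketch (not from the CUTLASS sources): with kAccessesPerVector accesses per
// thread-map element, a linear access index decomposes as vector -> contiguous -> strided,
// matching the new set_iteration_index() above. Toy sizes are assumed.
#include <cassert>

int main() {
  int const kAccessesPerVector = 2;
  int const kContiguous = 4;   // stands in for ThreadMap::Iterations::kContiguous

  for (int index = 0; index < kAccessesPerVector * kContiguous * 3; ++index) {
    int vector = index % kAccessesPerVector;
    int residual = index / kAccessesPerVector;
    int contiguous = residual % kContiguous;
    int strided = residual / kContiguous;

    // Reconstructing the linear index shows the decomposition is lossless.
    assert(index == vector + kAccessesPerVector * (contiguous + kContiguous * strided));
  }
  return 0;
}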
@@ -57,7 +57,8 @@ namespace threadblock {
template <
typename Shape_,
typename Element_,
typename ThreadMap_
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
>
class Conv2dWgradOutputGradientTileAccessIteratorOptimized {
public:
@@ -69,7 +70,7 @@ public:
using Element = Element_;
using Layout = layout::TensorNHWC;
using ThreadMap = ThreadMap_;
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
@@ -79,6 +80,11 @@ public:
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;

static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");

static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");

@@ -94,9 +100,10 @@ private:
Conv2dProblemSize const &problem_size_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;

uint32_t predicates_;
uint32_t predicates_[kAccessesPerVector];
int filter_k_;
int offset_npq_;

@@ -113,7 +120,7 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_(0),
predicates_{0},
filter_k_(0),
offset_npq_(0) {

@@ -130,13 +137,16 @@ public:
int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous;
int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided;

bool predicate = valid_(at_(offset_npq, filter_k));

uint32_t pred = (predicate ? 1u : 0);

int pred_idx = c + s * ThreadMap::Iterations::kContiguous;

predicates_ |= (pred << pred_idx);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements));

uint32_t pred = (predicate ? 1u : 0);

int pred_idx = c + s * ThreadMap::Iterations::kContiguous;

predicates_[v] |= (pred << pred_idx);
}
}
}

@@ -163,8 +173,10 @@ public:
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}

/// Adds a pointer offset in units of Element
@@ -183,7 +195,11 @@ public:
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) {
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
predicates_ = (predicates_ & (~kClearMask));

CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
predicates_[v] = (predicates_[v] & (~kClearMask));
}
}
}

@@ -229,7 +245,7 @@ public:
bool valid() const {

LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
return (predicates_ & (1u << pred_idx));
return (predicates_[iteration_vector_] & (1u << pred_idx));
}

/// Returns a pointer to the vector starting at the current coordinate
@@ -240,12 +256,18 @@ public:
pointer_ +
iteration_strided_ * params_.offset_next_strided +
iteration_contiguous_ * params_.offset_next_contiguous
);
) + iteration_vector_;
}

/// Increments to the next memory access
CUTLASS_HOST_DEVICE
Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;

++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
@@ -265,14 +287,14 @@ public:
static Status can_implement(Conv2dProblemSize const &problem_size) {

// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}

return Status::kSuccess;
}

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
@@ -280,5 +302,3 @@ public:
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////



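// Illustrative sketch (not from the CUTLASS sources): the optimized iterator now keeps one
// 32-bit predicate word per vectorized access. Bit (c + s * kContiguous) of predicates_[v]
// guards access v of the (c, s) thread-map element, which is how valid() tests it above.
#include <cassert>
#include <cstdint>

int main() {
  int const kAccessesPerVector = 2;
  int const kContiguous = 4;

  uint32_t predicates[kAccessesPerVector] = {0u, 0u};

  // Mark access (c = 1, s = 2, v = 1) as valid, the way the constructor builds the masks.
  int c = 1, s = 2, v = 1;
  int pred_idx = c + s * kContiguous;
  predicates[v] |= (1u << pred_idx);

  // valid() for the matching iteration state sees the bit set ...
  assert(predicates[1] & (1u << (1 + 2 * kContiguous)));
  // ... while the same bit in the other vector's mask stays clear.
  assert(!(predicates[0] & (1u << (1 + 2 * kContiguous))));
  return 0;
}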
@@ -79,6 +79,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;

static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or larger.");
@@ -259,5 +260,3 @@ public:
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////



@@ -82,6 +82,7 @@ public:
static StrideSupport const kStrideSupport = StrideSupport_;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;

//
// Parameters structure
@@ -215,7 +216,8 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);

predicates_ = (predicates_ & (~kClearMask));
}
}
@@ -279,5 +281,3 @@ public:
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////



@@ -93,6 +93,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;

static_assert(sizeof_bits<Element>::value >= 8,
"DGRAD requires elements of size 8b or greater.");
@@ -326,11 +327,11 @@ public:
}

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////



@@ -86,7 +86,7 @@ public:
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
using Coord3D = Coord<3>;

static int const kAccessesPerVector = 1;
using Mask = uint64_t;

//
@@ -401,7 +401,6 @@ public:
}

clear_mask_(filter_k_ >= problem_size_.K);

}



@@ -81,6 +81,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;

//
// Simplifying assertions

@@ -82,7 +82,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;

static int const kAccessesPerVector = 1;
using Mask = uint64_t;

//

@@ -80,6 +80,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;

//
// Simplifying assertions

@@ -82,6 +82,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;
static int const kAccessesPerVector = 1;

//
// Simplifying assertions
@@ -154,7 +155,7 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
predicates_(0),
predicates_{0},
filter_trs_(0),
filter_c_(0) {


@@ -79,6 +79,8 @@ public:
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;

static int const kAccessesPerVector = 1;

static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");


@@ -79,7 +79,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;

static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");


@@ -78,7 +78,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;

static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");


@@ -79,7 +79,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 3;
using ConvProblemSize = typename conv::Conv3dProblemSize;

static int const kAccessesPerVector = 1;
static_assert(sizeof_bits<Element>::value >= 8,
"WGRAD requires elements of size 8b or greater.");


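// Illustrative sketch (not from the CUTLASS sources): the Conv3d iterators above are not
// further vectorized, so they simply advertise a single access per thread-map element. A
// unified mainloop can then loop over kAccessesPerVector uniformly; with a value of 1 the
// inner loop collapses to one iteration.
#include <cassert>

int main() {
  int const kAccessesPerVector = 1;   // as declared by the 3-D iterators
  int issued = 0;
  for (int v = 0; v < kAccessesPerVector; ++v) {
    ++issued;                         // one copy per thread-map element
  }
  assert(issued == 1);
  return 0;
}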
@@ -0,0 +1,787 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
    \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel
    with the activation's scale+bias+relu fused into the mainloop.

    The original implicit GEMM stores out-of-bound data as zeroes in shared memory
    because zeroes fed into the tensor cores come back out as zeroes, so the result
    is unchanged. Once scale+bias+relu is fused into the mainloop this no longer
    holds, because

      0 x scale + bias = bias

    which is not always 0. So, instead of storing zeroes, this fused kernel stores
    out-of-bound data as a special NaN (0x7eff). When applying scale+bias+relu, the
    code behaves like

      if (data == 0x7eff)
        data = 0;
      else
        data = scale+bias+relu(data, scale, bias);

    See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the
    elementwise computation. See include/cutlass/arch/memory_sm80.h for the NaN fill.
*/

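// Illustrative sketch (not from the CUTLASS sources): the sentinel test described above,
// written out for a raw 16-bit pattern. 0x7eff is a half-precision NaN, so it can never be
// produced by valid activation data and safely marks out-of-bound elements.
#include <cassert>
#include <cstdint>

static float scale_bias_relu(uint16_t bits, float scale, float bias) {
  if (bits == 0x7eff) {
    return 0.0f;                              // out-of-bound element contributes nothing
  }
  float data = static_cast<float>(bits);      // placeholder decode; the real code uses half_t
  float result = data * scale + bias;
  return result > 0.0f ? result : 0.0f;       // ReLU
}

int main() {
  assert(scale_bias_relu(0x7eff, 2.0f, 3.0f) == 0.0f);  // sentinel maps to 0, not to bias
  assert(scale_bias_relu(1, 2.0f, 3.0f) == 5.0f);       // ordinary data: 1 * 2 + 3, ReLU'd
  return 0;
}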
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
|
||||
#include "cutlass/conv/warp/scale_bias_relu_transform.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Element type of scale and bias vectors
|
||||
typename ElementScaleBias_,
|
||||
/// Layout of scale and bias vectors
|
||||
typename LayoutScaleBias_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// WarpIterator to load Scale or Bias vector from the shared memory
|
||||
typename WarpIteratorScaleBias_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaFpropFusionBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Element type of scale and bias vectors
|
||||
using ElementScaleBias = ElementScaleBias_;
|
||||
|
||||
/// Layout of scale and bias vectors
|
||||
using LayoutScaleBias = LayoutScaleBias_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
///< WarpIterator to load Scale or Bias vector from the shared memory
|
||||
using WarpIteratorScaleBias = WarpIteratorScaleBias_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = cutlass::gemm::GemmShape<Shape::kM / WarpGemm::kM,
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the scale and bias vectors
|
||||
using TensorRefScaleBias = TensorRef<ElementScaleBias, LayoutScaleBias>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages +
|
||||
Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the A scale and bias vectors in shared memory
|
||||
using ShapeScaleBias =
|
||||
MatrixShape<1 + Policy::SmemPaddingA::kRow,
|
||||
2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB =
|
||||
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer for A operand Scale and Bias
|
||||
AlignedBuffer<ElementScaleBias, ShapeScaleBias::kCount> operand_A_scale_bias;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the A scale and bias vectors
|
||||
CUTLASS_DEVICE
|
||||
static LayoutScaleBias LayoutScaleBias() {
|
||||
return LayoutScaleBias::packed(
|
||||
{ShapeScaleBias::kRow, ShapeScaleBias::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand Scale vector
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefScaleBias operand_A_scale_bias_ref() {
|
||||
return TensorRefScaleBias{operand_A_scale_bias.data(), LayoutScaleBias()};
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand scale and bias vector
|
||||
/// from shared memory
|
||||
WarpIteratorScaleBias warp_tile_iterator_A_scale_bias_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaFpropFusionBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_A_scale_bias_(
|
||||
shared_storage.operand_A_scale_bias_ref(), lane_idx),
|
||||
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorScaleBias_,
|
||||
/// Iterates over vectors of scale and bias vector in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorScaleBias_,
|
||||
/// Cache operation for scale/bias operand
|
||||
cutlass::arch::CacheOperation::Kind CacheOpScaleBias,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// WarpIterator to load Scale or Bias vector from the shared memory
|
||||
typename WarpIteratorScaleBias_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class ImplicitGemmFpropFusionMultistage
|
||||
: public MmaFpropFusionBase<Shape_, typename IteratorScaleBias_::Element,
|
||||
typename IteratorScaleBias_::Layout, Policy_,
|
||||
WarpIteratorScaleBias_, Stages> {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorScaleBias = IteratorScaleBias_;
|
||||
///< WarpIterator to load Scale or Bias vector from the shared memory
|
||||
using WarpIteratorScaleBias = WarpIteratorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
///< Base class
|
||||
using Base = MmaFpropFusionBase<Shape_, typename IteratorScaleBias::Element,
|
||||
typename IteratorScaleBias::Layout, Policy_,
|
||||
WarpIteratorScaleBias, Stages>;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScaleBias = SmemIteratorScaleBias_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpScaleBias =
|
||||
CacheOpScaleBias;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
|
||||
using ElementC = typename Policy::Operator::ElementC;
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA =
|
||||
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB =
|
||||
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpLoadedFragmentScaleBias =
|
||||
typename WarpIteratorScaleBias::Fragment;
|
||||
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory
|
||||
SmemIteratorScaleBias smem_iterator_A_scale_bias_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
ImplicitGemmFpropFusionMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_A_scale_bias_(shared_storage.operand_A_scale_bias_ref(),
|
||||
thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_A_scale_bias_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(IteratorA &iterator_A,
|
||||
IteratorScaleBias &iterator_A_scale_bias,
|
||||
IteratorB &iterator_B, int group_start_A = 0,
|
||||
int group_start_B = 0) {
|
||||
iterator_A.set_iteration_index(group_start_A);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
// Uses nan fill for out of bound data
|
||||
cutlass::arch::cp_async_nan<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
// Async Copy for operand A scale and bias vector. Scale and bias vectors
|
||||
// are small. One iteration is enough.
|
||||
if (group_start_A == 0) {
|
||||
typename IteratorScaleBias::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorScaleBias::AccessType *>(
|
||||
this->smem_iterator_A_scale_bias_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorScaleBias::Element>::value *
|
||||
IteratorScaleBias::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpScaleBias>(
|
||||
dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid());
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B);
|
||||
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale and bias vectors in global memory
|
||||
IteratorScaleBias iterator_A_scale_bias,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum,
|
||||
///< Imaginary strides used for planar-complex only - ignored here
|
||||
int64_t imag_stride_A = 0,
|
||||
int64_t imag_stride_B = 0) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
// Uses Nan fill for out of bound data
|
||||
cutlass::arch::cp_async_nan<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
// Async Copy for operand A scale and bias vectors. Scale and bias
|
||||
// vectors are small. One iteration is enough.
|
||||
{
|
||||
typename IteratorScaleBias::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorScaleBias::AccessType *>(
|
||||
this->smem_iterator_A_scale_bias_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorScaleBias::Element>::value *
|
||||
IteratorScaleBias::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpScaleBias>(
|
||||
dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid());
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.advance();
|
||||
iterator_A_scale_bias.advance();
|
||||
iterator_B.advance();
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B[2];
|
||||
WarpLoadedFragmentScaleBias warp_loaded_frag_A_scale_bias[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B[2];
|
||||
|
||||
Operator warp_mma;
|
||||
cutlass::conv::warp::FpropScaleBiasReluTransform<WarpTransformedFragmentA,
|
||||
WarpLoadedFragmentScaleBias>
|
||||
elementwise_transform;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
||||
this->warp_tile_iterator_A_scale_bias_.load(
|
||||
warp_loaded_frag_A_scale_bias[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_A_scale_bias_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
|
||||
|
||||
elementwise_transform(warp_transformed_frag_A[0],
|
||||
warp_loaded_frag_A_scale_bias[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(
|
||||
(warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_A_scale_bias_.load(
|
||||
warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_A_scale_bias_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B[warp_mma_k % 2]);
|
||||
|
||||
elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_A_scale_bias[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
group_start_iteration_A = 0;
|
||||
group_start_iteration_B = 0;
|
||||
} else {
|
||||
group_start_iteration_A =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B,
|
||||
group_start_iteration_A,
|
||||
group_start_iteration_B);
|
||||
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
elementwise_transform(
|
||||
warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.advance();
|
||||
iterator_A_scale_bias.advance();
|
||||
iterator_B.advance();
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_A_scale_bias_.add_tile_offset(
|
||||
{0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_A_scale_bias_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -195,7 +195,8 @@ public:
|
||||
IteratorA &iterator_A, IteratorB &iterator_B,
|
||||
int group_start_A = 0, int group_start_B = 0) {
|
||||
|
||||
iterator_A.set_iteration_index(group_start_A);
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
@ -208,18 +209,23 @@ public:
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B);
|
||||
iterator_B.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
@ -232,12 +238,16 @@ public:
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
@ -279,14 +289,19 @@ public:
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
@ -300,14 +315,19 @@ public:
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
@ -356,6 +376,20 @@ public:
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
|
||||
|
||||
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
|
||||
// accumulator and this temporary accumulator is added to the final
|
||||
// accumulator once in every mainloop iteration.
|
||||
plus<FragmentC> plus_accum;
|
||||
|
||||
FragmentC tmp_accum;
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
tmp_accum.clear();
|
||||
}
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
@ -406,12 +440,29 @@ public:
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A,
|
||||
group_start_iteration_B);
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
warp_mma(
|
||||
tmp_accum,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
tmp_accum
|
||||
);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
accum = plus_accum(accum, tmp_accum);
|
||||
tmp_accum.clear();
|
||||
}
|
||||
} else {
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
@ -463,6 +514,13 @@ public:
|
||||
|
||||
}
|
||||
|
||||
if (platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddFastF32>::value
|
||||
|| platform::is_same<typename Operator::MathOperator,
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
accum = plus_accum(accum, tmp_accum);
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
|
||||
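// Illustrative sketch (not from the CUTLASS sources): the staged accumulation used by the
// tf32x3 path above, reduced to scalars. Partial products go into a temporary accumulator
// that is folded into the main accumulator once per mainloop iteration via plus<FragmentC>.
#include <cassert>

int main() {
  float accum = 0.0f;      // final accumulator
  float tmp_accum = 0.0f;  // staging accumulator, cleared after each fold

  float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float b[4] = {1.0f, 1.0f, 1.0f, 1.0f};

  for (int iter = 0; iter < 2; ++iter) {     // two "mainloop iterations"
    for (int k = 0; k < 2; ++k) {            // warp-level MMAs accumulate into tmp_accum
      tmp_accum += a[iter * 2 + k] * b[iter * 2 + k];
    }
    accum += tmp_accum;                      // fold once per mainloop iteration
    tmp_accum = 0.0f;
  }

  assert(accum == 10.0f);                    // 1 + 2 + 3 + 4
  return 0;
}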
@@ -0,0 +1,718 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
    \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel
    with the activation's scale+bias+relu fused into the mainloop.

    The original implicit GEMM stores out-of-bound data as zeroes in shared memory
    because zeroes fed into the tensor cores come back out as zeroes, so the result
    is unchanged. Once scale+bias+relu is fused into the mainloop this no longer
    holds, because

      0 x scale + bias = bias

    which is not always 0. So, instead of storing zeroes, this fused kernel stores
    out-of-bound data as a special NaN (0x7eff). When applying scale+bias+relu, the
    code behaves like

      if (data == 0x7eff)
        data = 0;
      else
        data = scale+bias+relu(data, scale, bias);

    The biggest difference compared with the fused Fprop with scale+bias+relu is
    that scale and bias are loop invariant in Wgrad, so they only need to be
    loaded once before the mainloop.

    See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the
    elementwise computation. See include/cutlass/arch/memory_sm80.h for the NaN fill.
*/

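// Illustrative sketch (not from the CUTLASS sources): because scale and bias are loop
// invariant in Wgrad, a single load hoisted above the mainloop is reused by every iteration
// instead of being re-issued once per stage.
#include <cassert>

int main() {
  int scale_bias_loads = 0;
  int const kGemmKIterations = 8;

  int scale_bias = ++scale_bias_loads;   // single load hoisted above the mainloop
  int acc = 0;
  for (int k = 0; k < kGemmKIterations; ++k) {
    acc += scale_bias;                   // mainloop reuses the same scale/bias fragment
  }

  assert(scale_bias_loads == 1 && acc == kGemmKIterations);
  return 0;
}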
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
|
||||
#include "cutlass/conv/warp/scale_bias_relu_transform.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Element type of scale and bias vectors
|
||||
typename ElementScaleBias_,
|
||||
/// Layout of scale and bias vectors
|
||||
typename LayoutScaleBias_,
|
||||
/// Element type of scale and bias vectors
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaWgradFusionBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Element type of scale and bias vectors
|
||||
using ElementScaleBias = ElementScaleBias_;
|
||||
|
||||
/// Layout of scale and bias vectors
|
||||
using LayoutScaleBias = LayoutScaleBias_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = cutlass::gemm::GemmShape<Shape::kM / WarpGemm::kM,
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages +
|
||||
Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB =
|
||||
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaWgradFusionBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorScaleBias_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class ImplicitGemmWgradFusionMultistage
|
||||
: public MmaWgradFusionBase<Shape_, typename IteratorScaleBias_::Element,
|
||||
typename IteratorScaleBias_::Layout, Policy_, Stages> {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorScaleBias = IteratorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
///< Base class
|
||||
using Base = MmaWgradFusionBase<Shape_, typename IteratorScaleBias::Element,
|
||||
typename IteratorScaleBias::Layout, Policy_, Stages>;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
|
||||
using ElementC = typename Policy::Operator::ElementC;
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand A
|
||||
static int const kAccessesPerGroupA =
|
||||
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand B
|
||||
static int const kAccessesPerGroupB =
|
||||
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
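    /// Number of register buffers used for the warp-level A fragments: a single
    /// buffer suffices when the accumulator is 32-bit, no operand conversion is
    /// needed, and the warp tile is at least 64x64; otherwise two buffers are used
    /// (presumably so the next fragment load can overlap the current math).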
static int const kBBufferSize =
|
||||
((sizeof(typename Operator::ElementC) == 4) &&
|
||||
((platform::is_same<typename Operator::Policy::Operator::ElementA,
|
||||
typename Operator::ElementA>::value &&
|
||||
platform::is_same<typename Operator::Policy::Operator::ElementB,
|
||||
typename Operator::ElementB>::value)) &&
|
||||
(Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64))
|
||||
? 1
|
||||
: 2;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpLoadedFragmentScaleBias = typename IteratorScaleBias::Fragment;
|
||||
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
int warp_idx_m_;
|
||||
|
||||
int warp_idx_n_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
ImplicitGemmWgradFusionMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM;
|
||||
warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(IteratorA &iterator_A,
|
||||
IteratorB &iterator_B,
|
||||
int group_start_A = 0, int group_start_B = 0) {
|
||||
|
||||
iterator_A.set_iteration_index(group_start_A);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B);
|
||||
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
// Uses NaN fill for out-of-bounds data
|
||||
cutlass::arch::cp_async_nan<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale and bias vectors in global memory
|
||||
IteratorScaleBias iterator_B_scale_bias,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum,
|
||||
///< Imaginary strides used for planar-complex only - ignored here
|
||||
int64_t imag_stride_A = 0,
|
||||
int64_t imag_stride_B = 0) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
WarpLoadedFragmentScaleBias warp_loaded_frag_B_scale_bias;
|
||||
iterator_B_scale_bias.add_tile_offset({0, warp_idx_n_});
|
||||
iterator_B_scale_bias.load(warp_loaded_frag_B_scale_bias);
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
// Uses NaN fill for out-of-bounds data
|
||||
cutlass::arch::cp_async_nan<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.advance();
|
||||
iterator_B.advance();
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A[Detail::kBBufferSize];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A[Detail::kBBufferSize];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B[2];
|
||||
|
||||
Operator warp_mma;
|
||||
cutlass::conv::warp::WgradScaleBiasReluTransform<WarpTransformedFragmentB,
|
||||
WarpLoadedFragmentScaleBias>
|
||||
elementwise_transform;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance(iterator_A, iterator_B);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
|
||||
|
||||
elementwise_transform(warp_transformed_frag_B[0],
|
||||
warp_loaded_frag_B_scale_bias);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
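    // Note: the loop continues for kStages - 1 iterations after gemm_k_iterations
    // reaches zero so that the stages already committed by the prologue are
    // consumed; trailing out-of-range global loads are handled by the iterators'
    // predicates.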
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
if (Detail::kBBufferSize == 2) {
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % Detail::kBBufferSize],
|
||||
warp_loaded_frag_B[warp_mma_k % 2]);
|
||||
|
||||
elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2],
|
||||
warp_loaded_frag_B_scale_bias);
|
||||
}
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
if (Detail::kBBufferSize == 1) {
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
}
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize],
|
||||
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize],
|
||||
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
elementwise_transform(
|
||||
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B_scale_bias);
|
||||
}
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
group_start_iteration_A = 0;
|
||||
group_start_iteration_B = 0;
|
||||
} else {
|
||||
group_start_iteration_A =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B,
|
||||
group_start_iteration_A,
|
||||
group_start_iteration_B);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.advance();
|
||||
iterator_B.advance();
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
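// A minimal usage sketch, assuming a typical CUTLASS implicit GEMM kernel drives
// this mainloop; `Mma` abbreviates the ImplicitGemmWgradFusionMultistage
// instantiation, and `warp_id`, `lane_id`, the iterators, and `gemm_k_iterations`
// are placeholders supplied by that kernel:
//
//   __shared__ typename Mma::SharedStorage shared_storage;
//   Mma mma(shared_storage, threadIdx.x, warp_id, lane_id);
//
//   typename Mma::FragmentC accumulators;
//   accumulators.clear();
//
//   mma(gemm_k_iterations, accumulators, iterator_A, iterator_B,
//       iterator_scale_bias, accumulators);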
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,393 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT,INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Templates calculating the addresses and predicates for loading the scale and bias vectors.
|
||||
|
||||
This iterator uses masks to guard out-of-bounds accesses.
|
||||
|
||||
A precomputed "Params" object minimizes the amount of state that must be
|
||||
stored in registers, and integer addition is used to advance the pointer
|
||||
through memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// PredicatedScaleBiasVectorAccessIterator
|
||||
///
|
||||
template <typename ThreadblockShape,
|
||||
typename Element,
|
||||
typename Layout>
|
||||
class PredicatedScaleBiasVectorAccessIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization of PredicatedScaleBiasVectorAccessIterator for fprop pitch-linear data.
|
||||
///
|
||||
template <typename ThreadblockShape_, typename Element_>
|
||||
class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
|
||||
Element_,
|
||||
layout::PitchLinear> {
|
||||
public:
|
||||
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::PitchLinear;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
using ConstPointer = const Element *;
|
||||
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
||||
|
||||
static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
|
||||
static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess;
|
||||
|
||||
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
||||
|
||||
using Params = PredicatedScaleBiasVectorAccessIteratorParams;
|
||||
|
||||
private:
|
||||
/// Internal pointer type permits fast address arithmetic
|
||||
using BytePointer = char *;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object with precomputed internal state
|
||||
Params const ¶ms_;
|
||||
|
||||
/// Internal pointer to first access of tile
|
||||
BytePointer pointer_;
|
||||
|
||||
/// Size of tensor
|
||||
Conv2dProblemSize problem_size_;
|
||||
|
||||
int filter_c_;
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
|
||||
TensorCoord thread_offset_;
|
||||
|
||||
public:
|
||||
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
||||
/// and thread ID
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator(
|
||||
/// Precomputed parameters object
|
||||
Params const ¶ms,
|
||||
/// Extent of tensor
|
||||
Conv2dProblemSize const &problem_size,
|
||||
/// Pointer to the start of the scale vector
|
||||
ConstPointer scale_pointer,
|
||||
/// Pointer to the start of the bias vector
|
||||
ConstPointer bias_pointer,
|
||||
/// ID of each participating thread
|
||||
int thread_id,
|
||||
/// Initial offset of threadblock
|
||||
TensorCoord const &threadblock_offset)
|
||||
: params_(params),
|
||||
problem_size_(problem_size),
|
||||
filter_c_(0),
|
||||
filter_r_(0),
|
||||
filter_s_(0) {
|
||||
pointer_ = (thread_id < kThreads)
|
||||
? reinterpret_cast<BytePointer>(
|
||||
const_cast<NonConstPointer>(scale_pointer))
|
||||
: reinterpret_cast<BytePointer>(
|
||||
const_cast<NonConstPointer>(bias_pointer));
|
||||
|
||||
// Per-thread offset in logical coordinates of tensor
|
||||
int thread_base = (thread_id < kThreads) ? 0 : kThreads;
|
||||
|
||||
thread_offset_ =
|
||||
threadblock_offset +
|
||||
TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0);
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Construct a PredicatedScaleBiasVectorAccessIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator(
|
||||
/// Precomputed parameters object
|
||||
Params const ¶ms,
|
||||
/// Extent of tensor
|
||||
Conv2dProblemSize const &problem_size,
|
||||
/// Pointer to start of scale vector
|
||||
ConstPointer scale_pointer,
|
||||
/// Pointer to start of bias vector
|
||||
ConstPointer bias_pointer,
|
||||
///< ID of each participating thread
|
||||
int thread_id)
|
||||
: PredicatedScaleBiasVectorAccessIterator(params, problem_size,
|
||||
scale_pointer, bias_pointer,
|
||||
thread_id, make_Coord(0, 0)) {}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(int index) {}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(
|
||||
TensorCoord const &tile_offset) {
|
||||
thread_offset_ =
|
||||
thread_offset_ +
|
||||
TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0);
|
||||
}
|
||||
|
||||
/// Returns a pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType *get() const {
|
||||
|
||||
return reinterpret_cast<AccessType *>(
|
||||
pointer_ +
|
||||
(thread_offset_.contiguous() * sizeof_bits<Element>::value / 8));
|
||||
}
|
||||
|
||||
/// Increments and returns a reference to self.
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator &operator++() {
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// moves to the next tile
|
||||
++filter_s_;
|
||||
if (filter_s_ == problem_size_.S) {
|
||||
filter_s_ = 0;
|
||||
++filter_r_;
|
||||
|
||||
      if (filter_r_ >= problem_size_.R) {
        filter_r_ = 0;
        add_tile_offset(TensorCoord(1, 0));
      }
|
||||
}
|
||||
}
|
||||
|
||||
/// Post-increment: returns a copy of self prior to the increment.
|
||||
CUTLASS_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator operator++(int) {
|
||||
PredicatedScaleBiasVectorAccessIterator self(*this);
|
||||
operator++();
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Returns whether access is valid or not
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
uint32_t enabled = 0;
|
||||
|
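    // Both branches set 'enabled' when threadIdx.x < 2 * kThreads: the first
    // kThreads threads read the scale vector and the next kThreads threads read
    // the bias vector.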
||||
#if defined(_MSC_VER) || (__CUDACC_VER_MAJOR__ < 11)
|
||||
enabled = threadIdx.x < kThreads * 2;
|
||||
#else
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .u32 tid_reg;\n"
|
||||
" .reg .pred p;\n"
|
||||
" mov.u32 tid_reg, %%tid.x;\n"
|
||||
" setp.lt.u32 p, tid_reg, %1;\n"
|
||||
" selp.u32 %0, 1, 0, p;\n"
|
||||
"}\n" : "+r"(enabled) :"n"(kThreads * 2));
|
||||
#endif
|
||||
|
||||
return ((thread_offset_.contiguous() < problem_size_.C) && enabled);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization of PredicatedScaleBiasVectorAccessIterator for row-major data.
|
||||
///
|
||||
/// Satisfies: ForwardTileIteratorConcept |
|
||||
/// ReadableContiguousTileIteratorConcept |
|
||||
/// WriteableContiguousTileIteratorConcept |
|
||||
/// MaskedTileIteratorConcept
|
||||
///
|
||||
template <typename ThreadblockShape_,
|
||||
typename Element_>
|
||||
class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
|
||||
Element_,
|
||||
layout::RowMajor> {
|
||||
public:
|
||||
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
using ConstPointer = const Element *;
|
||||
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
||||
|
||||
using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator<
|
||||
layout::PitchLinearShape<ThreadblockShape::kColumn, ThreadblockShape::kRow>,
|
||||
Element,
|
||||
layout::PitchLinear>;
|
||||
|
||||
using AccessType = typename UnderlyingIterator::AccessType;
|
||||
static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
|
||||
|
||||
using Params = PredicatedScaleBiasVectorAccessIteratorParams;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Underlying pitch-linear tile iterator
|
||||
UnderlyingIterator iterator_;
|
||||
|
||||
public:
|
||||
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
||||
/// and thread ID
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator(
|
||||
///< Precomputed parameters object
|
||||
Params const ¶ms,
|
||||
///< Extent of tensor
|
||||
Conv2dProblemSize const &problem_size,
|
||||
///< Pointer to the start of the scale vector
|
||||
ConstPointer scale_pointer,
|
||||
///< Pointer to the start of the bias vector
|
||||
ConstPointer bias_pointer,
|
||||
///< ID of each participating thread
|
||||
int thread_id,
|
||||
///< Initial offset of threadblock
|
||||
TensorCoord const &threadblock_offset)
|
||||
: iterator_(params, problem_size, scale_pointer, bias_pointer,
|
||||
thread_id,
|
||||
layout::PitchLinearCoord(threadblock_offset.column(),
|
||||
threadblock_offset.row())) {}
|
||||
|
||||
/// Construct a PredicatedScaleBiasVectorAccessIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator(
|
||||
Params const ¶ms, ///< Precomputed parameters object
|
||||
Conv2dProblemSize const &problem_size, ///< Extent of tensor
|
||||
ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
|
||||
ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
|
||||
int thread_id ///< ID of each participating thread
|
||||
)
|
||||
: PredicatedScaleBiasVectorAccessIterator(params, problem_size,
|
||||
scale_pointer, bias_pointer,
|
||||
thread_id, make_Coord(0, 0)) {}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole
|
||||
/// threadblock tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_tile_offset(TensorCoord const &tile_offset) {
|
||||
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
||||
}
|
||||
|
||||
/// Returns a pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType *get() const {
|
||||
return reinterpret_cast<AccessType *>(iterator_.get());
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
///
|
||||
/// The first time this method is called, predicates are updated, and the
|
||||
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
||||
/// Subsequent calls are lightweight and must only update the internal
|
||||
/// pointer.
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator &operator++() {
|
||||
++iterator_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
///
|
||||
/// The first time this method is called, predicates are updated, and the
|
||||
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
||||
/// Subsequent calls are lightweight and must only update the internal
|
||||
/// pointer.
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorAccessIterator operator++(int) {
|
||||
PredicatedScaleBiasVectorAccessIterator self(*this);
|
||||
operator++();
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
iterator_.advance();
|
||||
}
|
||||
|
||||
/// Returns whether access is valid or not
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
return iterator_.valid();
|
||||
}
|
||||
};
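// A minimal usage sketch, assuming a fprop fusion kernel supplies the problem size
// and device pointers; `ThreadblockShape` and `ElementScaleBias` stand for that
// kernel's own template arguments:
//
//   using ScaleBiasIterator = cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
//       ThreadblockShape, ElementScaleBias, cutlass::layout::RowMajor>;
//
//   typename ScaleBiasIterator::Params params;
//   ScaleBiasIterator iter(params, problem_size, scale_ptr, bias_ptr, threadIdx.x);
//
//   if (iter.valid()) {
//     typename ScaleBiasIterator::AccessType const *access = iter.get();
//     // ... stage the access into shared memory ...
//   }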
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,365 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT,INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Templates calculating the addresses and predicates for loading the scale and bias vectors.
|
||||
|
||||
This iterator uses masks to guard out-of-bounds accesses.
|
||||
|
||||
A precomputed "Params" object minimizes the amount of state that must be
|
||||
stored in registers, and integer addition is used to advance the pointer
|
||||
through memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// PredicatedScaleBiasVectorIterator
|
||||
///
|
||||
template <typename WarpShape,
|
||||
typename Element,
|
||||
typename Layout>
|
||||
class PredicatedScaleBiasVectorIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization of PredicatedScaleBiasVectorIterator for wgrad pitch-linear data.
|
||||
///
|
||||
template <typename WarpShape_, typename Element_>
|
||||
class PredicatedScaleBiasVectorIterator<WarpShape_,
|
||||
Element_,
|
||||
layout::PitchLinear> {
|
||||
public:
|
||||
|
||||
using WarpShape = WarpShape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::PitchLinear;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
using ConstPointer = const Element *;
|
||||
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
||||
|
||||
static int const kElementsPerAccess = 1;
|
||||
|
||||
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
||||
|
||||
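  /// Each iteration covers a group of 8 channels along the warp tile's contiguous
  /// dimension (add_tile_offset below advances rsc_offset by 8 per iteration).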
static int const kIterations = WarpShape::kContiguous / 8;
|
||||
|
||||
/// Fragment object to be loaded or stored
|
||||
using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>;
|
||||
|
||||
/// Parameters object is precomputed state and is host-constructible
|
||||
using Params = Conv2dWgradActivationIteratorOptimizedParams;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object with precomputed internal state
|
||||
Params const ¶ms_;
|
||||
|
||||
/// Internal pointer to first access of tile
|
||||
ConstPointer scale_pointer_;
|
||||
ConstPointer bias_pointer_;
|
||||
|
||||
/// Size of tensor
|
||||
Conv2dProblemSize problem_size_;
|
||||
|
||||
int32_t thread_offset_;
|
||||
|
||||
// The channel coordinate in the contiguous dimension stays constant for each gemm_k iteration
|
||||
int32_t filter_c_[kIterations];
|
||||
|
||||
public:
|
||||
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
||||
/// and thread ID
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorIterator(
|
||||
/// Precomputed parameters object
|
||||
Params const ¶ms,
|
||||
/// Extent of tensor
|
||||
Conv2dProblemSize const &problem_size,
|
||||
/// Pointer to the start of the scale vector
|
||||
ConstPointer scale_pointer,
|
||||
/// Pointer to the start of the bias vector
|
||||
ConstPointer bias_pointer,
|
||||
/// ID of each participating thread
|
||||
int thread_id,
|
||||
/// Initial offset of threadblock
|
||||
TensorCoord const &threadblock_offset)
|
||||
: params_(params),
|
||||
problem_size_(problem_size),
|
||||
scale_pointer_(scale_pointer),
|
||||
bias_pointer_(bias_pointer) {
|
||||
|
||||
thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4;
|
||||
}
|
||||
|
||||
/// Construct a PredicatedScaleBiasVectorIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorIterator(
|
||||
/// Precomputed parameters object
|
||||
Params const ¶ms,
|
||||
/// Extent of tensor
|
||||
Conv2dProblemSize const &problem_size,
|
||||
/// Pointer to start of scale vector
|
||||
ConstPointer scale_pointer,
|
||||
/// Pointer to start of bias vector
|
||||
ConstPointer bias_pointer,
|
||||
///< ID of each participating thread
|
||||
int thread_id)
|
||||
: PredicatedScaleBiasVectorIterator(params, problem_size,
|
||||
scale_pointer, bias_pointer,
|
||||
thread_id, make_Coord(0, 0)) {}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole warp tiles
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(
|
||||
TensorCoord const &tile_offset) {
|
||||
|
||||
thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int c = 0; c < kIterations; ++c) {
|
||||
int rsc_offset = thread_offset_ + c * 8;
|
||||
|
||||
int residual, tmp;
|
||||
params_.sc_divmod(tmp, residual, rsc_offset);
|
||||
params_.c_divmod(tmp, filter_c_[c], residual);
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
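  ///
  /// For each of the kIterations channel groups, fragment element 2*c holds the
  /// scale value and element 2*c+1 holds the bias value, each duplicated into
  /// both halves of a __half2.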
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
||||
|
||||
frag.fill(__float2half2_rn(0.0f));
|
||||
__half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag);
|
||||
|
||||
// load scale
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < kIterations; ++c) {
|
||||
|
||||
cutlass::arch::global_load<
|
||||
__half,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[c * 2].x,
|
||||
scale_pointer_ + filter_c_[c],
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
// load bias
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < kIterations; ++c) {
|
||||
|
||||
cutlass::arch::global_load<
|
||||
__half,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[c * 2 + 1].x,
|
||||
bias_pointer_ + filter_c_[c],
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
// duplicate scale
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < kIterations; ++c) {
|
||||
frag_ptr[c * 2].y = frag_ptr[c * 2].x;
|
||||
}
|
||||
|
||||
// duplicate bias
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < kIterations; ++c) {
|
||||
frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x;
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization of PredicatedScaleBiasVectorIterator for row-major data.
|
||||
///
|
||||
/// Satisfies: ForwardTileIteratorConcept |
|
||||
/// ReadableContiguousTileIteratorConcept |
|
||||
/// WriteableContiguousTileIteratorConcept |
|
||||
/// MaskedTileIteratorConcept
|
||||
///
|
||||
template <typename WarpShape_,
|
||||
typename Element_>
|
||||
class PredicatedScaleBiasVectorIterator<WarpShape_,
|
||||
Element_,
|
||||
layout::RowMajor> {
|
||||
public:
|
||||
|
||||
using WarpShape = WarpShape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
using ConstPointer = const Element *;
|
||||
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
||||
|
||||
using UnderlyingIterator = PredicatedScaleBiasVectorIterator<
|
||||
layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
|
||||
Element,
|
||||
layout::PitchLinear>;
|
||||
|
||||
using AccessType = typename UnderlyingIterator::AccessType;
|
||||
static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
|
||||
using Fragment = typename UnderlyingIterator::Fragment;
|
||||
|
||||
/// Parameters object is precomputed state and is host-constructible
|
||||
class Params {
|
||||
private:
|
||||
friend PredicatedScaleBiasVectorIterator;
|
||||
|
||||
/// Parameters object
|
||||
typename UnderlyingIterator::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
/// Construct the Params object given a pitch-linear tensor's layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Conv2dProblemSize const &problem_size, Layout const &layout)
|
||||
: params_(problem_size, layout::TensorNHWC(0, 0, 0)){};
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Underlying pitch-linear tile iterator
|
||||
UnderlyingIterator iterator_;
|
||||
|
||||
public:
|
||||
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
||||
/// and thread ID
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorIterator(
|
||||
///< Precomputed parameters object
|
||||
Params const ¶ms,
|
||||
///< Extent of tensor
|
||||
Conv2dProblemSize const &problem_size,
|
||||
///< Pointer to the start of the scale vector
|
||||
ConstPointer scale_pointer,
|
||||
///< Pointer to the start of the bias vector
|
||||
ConstPointer bias_pointer,
|
||||
///< ID of each participating thread
|
||||
int thread_id,
|
||||
///< Initial offset of threadblock
|
||||
TensorCoord const &threadblock_offset)
|
||||
: iterator_(params.params_, problem_size, scale_pointer, bias_pointer,
|
||||
thread_id,
|
||||
layout::PitchLinearCoord(threadblock_offset.column(),
|
||||
threadblock_offset.row())) {}
|
||||
|
||||
/// Construct a PredicatedScaleBiasVectorIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedScaleBiasVectorIterator(
|
||||
Params const ¶ms, ///< Precomputed parameters object
|
||||
Conv2dProblemSize const &problem_size, ///< Extent of tensor
|
||||
ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
|
||||
ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
|
||||
int thread_id ///< ID of each participating thread
|
||||
)
|
||||
: PredicatedScaleBiasVectorIterator(params, problem_size,
|
||||
scale_pointer, bias_pointer,
|
||||
thread_id, make_Coord(0, 0)) {}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole
|
||||
/// threadblock tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_tile_offset(TensorCoord const &tile_offset) {
|
||||
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
||||
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
iterator_.load(frag);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,247 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
    \brief Templates computing the addresses for storing small scale and bias
           vectors in shared memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// RegularScaleBiasVectorAccessIterator
|
||||
///
|
||||
template <typename Shape, typename Element, typename Layout>
|
||||
class RegularScaleBiasVectorAccessIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator specialized for congruous arrangements for TensorOps
|
||||
///
|
||||
///
|
||||
/// Satisfies: ForwardTileIteratorConcept |
|
||||
/// ReadableContiguousTileIteratorConcept |
|
||||
/// WriteableContiguousTileIteratorConcept
|
||||
///
|
||||
template <typename Shape_, typename Element_>
|
||||
class RegularScaleBiasVectorAccessIterator<Shape_, Element_, layout::PitchLinear> {
|
||||
public:
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::PitchLinear;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
/// Number of elements per access
|
||||
static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
|
||||
static int const kThreads = Shape::kContiguous / kElementsPerAccess;
|
||||
using AccessType = Array<Element, kElementsPerAccess>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Internal pointer
|
||||
AccessType *pointer_;
|
||||
|
||||
/// Internal byte offset
|
||||
Index byte_offset_;
|
||||
|
||||
public:
|
||||
/// Construct a TileIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularScaleBiasVectorAccessIterator(
|
||||
TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias
|
||||
///< vector
|
||||
int thread_id ///< ID of each participating thread
|
||||
)
|
||||
: byte_offset_(0) {
|
||||
// Per-thread offset in logical coordinates of tensor
|
||||
int thread_offset = thread_id * kElementsPerAccess;
|
||||
|
||||
// initialize pointer
|
||||
pointer_ =
|
||||
reinterpret_cast<AccessType *>(scale_bias_ref.data() + thread_offset);
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(int index) {}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
byte_offset_ += pointer_offset * sizeof(Element);
|
||||
}
|
||||
|
||||
/// Returns a pointer
|
||||
CUTLASS_DEVICE
|
||||
AccessType *get() const {
|
||||
|
||||
char *access_byte_ptr =
|
||||
reinterpret_cast<char *>(pointer_);
|
||||
|
||||
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularScaleBiasVectorAccessIterator &operator++() { return *this; }
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularScaleBiasVectorAccessIterator operator++(int) {
|
||||
RegularScaleBiasVectorAccessIterator prev(*this);
|
||||
this->operator++();
|
||||
|
||||
return prev;
|
||||
}
|
||||
|
||||
/// Adds a tile offset in the unit of tile.
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const &coord) {
|
||||
    // Multiply by 2 because the scale and bias vectors belonging to the same
    // stage are stored next to each other.
|
||||
add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2);
|
||||
}
|
||||
};
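// A minimal usage sketch, assuming the threadblock has already reserved a shared
// memory buffer holding the scale vector followed by the bias vector for each
// stage; `SmemShape`, `scale_bias_ref`, and `staged_access` are placeholder names:
//
//   using SmemScaleBiasIterator = cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
//       SmemShape, cutlass::half_t, cutlass::layout::PitchLinear>;
//
//   SmemScaleBiasIterator smem_iter(scale_bias_ref, threadIdx.x);
//   *smem_iter.get() = staged_access;       // write one 128-bit access
//   smem_iter.add_tile_offset({1, 0});      // step to the next stage's scale/bias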
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator specialized for row major layouts
|
||||
///
|
||||
///
|
||||
/// Satisfies: ForwardTileIteratorConcept |
|
||||
/// ReadableContiguousTileIteratorConcept |
|
||||
/// WriteableContiguousTileIteratorConcept
|
||||
///
|
||||
template <typename Shape_, typename Element_>
|
||||
class RegularScaleBiasVectorAccessIterator<
|
||||
Shape_, Element_,
|
||||
layout::RowMajor> {
|
||||
public:
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
|
||||
/// Underlying iterator type
|
||||
using UnderlyingIterator = RegularScaleBiasVectorAccessIterator<
|
||||
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
||||
layout::PitchLinear>;
|
||||
|
||||
using AccessType = typename UnderlyingIterator::AccessType;
|
||||
|
||||
private:
|
||||
|
||||
/// Underlying iterator
|
||||
UnderlyingIterator iterator_;
|
||||
|
||||
public:
|
||||
/// Construct a TileIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularScaleBiasVectorAccessIterator(
|
||||
TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias
|
||||
///< vector
|
||||
int thread_id ///< ID of each participating thread
|
||||
)
|
||||
: iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) {
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
iterator_.add_pointer_offset(pointer_offset);
|
||||
}
|
||||
|
||||
/// Returns a pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType *get() const {
|
||||
return reinterpret_cast<AccessType *>(iterator_.get());
|
||||
}
|
||||
|
||||
/// Adds a tile offset
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const &coord) {
|
||||
iterator_.add_tile_offset({coord.column(), coord.row()});
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularScaleBiasVectorAccessIterator &operator++() {
|
||||
++iterator_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularScaleBiasVectorAccessIterator operator++(int) {
|
||||
RegularScaleBiasVectorAccessIterator prev(*this);
|
||||
++iterator_;
|
||||
|
||||
return prev;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|