CUTLASS 2.4 (Implicit GEMM convolution) (#147)
CUTLASS 2.4 (Implicit GEMM Convolution) Co-authored-by: Manish Gupta <manigupta@nvidia.com>, Haicheng Wu <haichengw@nvidia.com>, Dustyn Blasig <dblasig@nvidia.com>, Andrew Kerr <akerr@nvidia.com>
CHANGELOG.md (11 lines changed)

@@ -1,6 +1,17 @@
# NVIDIA CUTLASS Changelog

# CUTLASS 2.x

## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19)
 * Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs
   * Operators: forward (Fprop), backward data gradient (Dgrad), and backward weight gradient (Wgrad) convolution
   * Data type: FP32, complex<FP32>, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32
   * Spatial dimensions: 1-D, 2-D, and 3-D
   * Layout: NHWC, NCxHWx
 * Implicit GEMM convolution components:
   * Global memory iterators supporting fprop, dgrad, and wgrad
   * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
   * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
   * [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation

## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
 * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
CMakeLists.txt (188 lines changed)

@@ -32,7 +32,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")

project(CUTLASS VERSION 2.3.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.4.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

find_package(Doxygen QUIET)

@@ -137,7 +137,12 @@ if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
endif()

set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries")
if (DEFINED CMAKE_DEBUG_POSTFIX)
  set(CUTLASS_LIBRARY_DEBUG_POSTFIX_INIT ${CMAKE_DEBUG_POSTFIX})
else()
  set(CUTLASS_LIBRARY_DEBUG_POSTFIX_INIT .debug)
endif()
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ${CUTLASS_LIBRARY_DEBUG_POSTFIX_INIT} CACHE STRING "Default postfix value for debug libraries")

if(WIN32)
  # On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.

@@ -192,7 +197,6 @@ endif()
set(CUTLASS_DEBUG_TRACE_LEVEL "0" CACHE STRING "Level of debug tracing to perform.")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_TRACE_LEVEL})

set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
  "Enable PTX mma instruction for collective matrix multiply operations.")

@@ -466,21 +470,195 @@ if (CUTLASS_ENABLE_CUBLAS)
  target_compile_definitions(CUTLASS INTERFACE CUTLASS_ENABLE_CUBLAS=1)
endif()

include(${CMAKE_CURRENT_SOURCE_DIR}/cuDNN.cmake)

if (CUTLASS_ENABLE_CUDNN)
  target_compile_definitions(CUTLASS INTERFACE CUTLASS_ENABLE_CUDNN=1)
endif()

################################################################################

include(CTest)
enable_testing()
if (NOT TARGET test_all)
  add_custom_target(test_all)
endif()

set(CUTLASS_INSTALL_TESTS ON CACHE BOOL "Install test executables")
set(CUTLASS_TEST_EXECUTION_ENVIRONMENT "" CACHE BOOL "Environment in which to invoke unit test executables")

set(CMAKE_TEST_INSTALL_PREFIX test CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
set(CUTLASS_TEST_INSTALL_PREFIX ${CMAKE_TEST_INSTALL_PREFIX}/cutlass CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
set(CUTLASS_TEST_INSTALL_BINDIR ${CUTLASS_TEST_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR} CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
set(CUTLASS_TEST_INSTALL_LIBDIR ${CUTLASS_TEST_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")

install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX})
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR})
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_LIBDIR})
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest)

set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake)
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")

function(cutlass_add_executable_tests NAME TARGET)
  #
  # Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the
  # <CMAKE_BINARY_DIR> or the <CMAKE_INSTALL_PREFIX>/<CUTLASS_TEST_INSTALL_PREFIX> after installation.
  #
  # NAME: The base name for the test. Can be run with `make <NAME>` or `ctest -R 'c<NAME>'`.
  # TARGET: The target corresponding to the executable under test.
  # DISABLE_EXECUTABLE_INSTALL_RULE: An option, if given, that disables creating an install rule for TARGET.
  # DEPENDS: A list of targets or files on which this test is dependent.
  # DEPENDEES: A list of targets which should depend on this test.
  # TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
  #   to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
  #   options given. If this option is not used, a single test with no arguments is generated.
  #

  set(options DISABLE_EXECUTABLE_INSTALL_RULE)
  set(oneValueArgs)
  set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
  cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS)

    # file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR})

    install(
      TARGETS ${TARGET}
      RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
      )

  endif()

  if (NOT __TEST_COMMAND_OPTIONS)
    set(__TEST_COMMAND_OPTIONS " ")
  endif()

  list(LENGTH __TEST_COMMAND_OPTIONS CMD_COUNT)
  set(CMD_IDX 0)

  if (CMD_COUNT GREATER 1)
    add_custom_target(${NAME} DEPENDS ${TARGET} ${__DEPENDS})
    foreach(DEPENDEE ${__DEPENDEES})
      add_dependencies(${DEPENDEE} ${NAME})
    endforeach()
  endif()

  foreach(CMD_OPTIONS ${__TEST_COMMAND_OPTIONS})

    if (CMD_COUNT GREATER 1)
      set(TEST_NAME ${NAME}_${CMD_IDX})
    else()
      set(TEST_NAME ${NAME})
    endif()

    # The following rigmarole is needed to deal with spaces and possible quotes in
    # command line arguments. The options are passed "by reference" as the actual
    # variable names holding the real options. We then expand these in a way that
    # preserves any quotes. Note, they have to be in this order for it to work for
    # all the use cases below.

    set(CMD_OPTIONS ${${CMD_OPTIONS}})
    list(JOIN CMD_OPTIONS " " TEST_COMMAND_OPTIONS)
    separate_arguments(CMD_OPTIONS)

    add_custom_target(
      ${TEST_NAME}
      COMMAND
        ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
      DEPENDS
        ${TARGET}
      )

    if (CMD_COUNT GREATER 1)
      add_dependencies(${NAME} ${TEST_NAME})
    endif()

    foreach(DEPENDEE ${__DEPENDEES})
      add_dependencies(${DEPENDEE} ${TEST_NAME})
    endforeach()

    add_test(
      NAME c${TEST_NAME}
      COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
      )

    if (CUTLASS_INSTALL_TESTS)

      # To run the tests from an install package with tests enabled, we need to generate test files
      # that don't rely on the current directory structure in build.

      set(TEST_NAME c${TEST_NAME})
      set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
      set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
      configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY)

      file(GENERATE
        OUTPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
        INPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake"
        )

      install(
        FILES "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
        DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/
        )

      set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")

    endif()

    math(EXPR CMD_IDX "${CMD_IDX} + 1")

  endforeach()

endfunction()

if (CUTLASS_ENABLE_TOOLS)
  add_subdirectory(tools)
  if (CUTLASS_ENABLE_PROFILER)
    add_dependencies(test_all test_profiler)
  endif()
endif()
if (CUTLASS_ENABLE_EXAMPLES)
  add_subdirectory(examples)
  add_dependencies(test_all test_examples)
endif()

if (CUTLASS_ENABLE_TESTS)
  include(CTest)
  enable_testing()
  add_subdirectory(test)
  add_dependencies(test_all test_unit)
endif()

if (CUTLASS_INSTALL_TESTS)

  file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake")

  file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n")
  foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
    file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
  endforeach()

  install(
    FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
    DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
    )

endif()

#? install(
#?   FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake
#?   DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
#?   )
#?
#? install(
#?   DIRECTORY
#?     ${CMAKE_BINARY_DIR}/tools
#?     ${CMAKE_BINARY_DIR}/test
#?   DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
#?   FILES_MATCHING PATTERN "CTestTestfile.cmake"
#?   )

################################################################################

install(
README.md (87 lines changed)

@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 2.3
# CUTLASS 2.4

_CUTLASS 2.3 - September 2020_
_CUTLASS 2.4 - November 2020_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.

@@ -25,11 +25,22 @@ Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, and Ampere architectures.

Additionally, CUTLASS implements high-performance convolution (implicit GEMM).
Implicit GEMM is the formulation of a convolution operation as a GEMM. This allows CUTLASS
to build convolutions by reusing highly optimized warp-wide GEMM components and below.
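As a concrete, deliberately simplified sketch of what "convolution as GEMM" means for forward propagation: the implicit GEMM problem extents are derived from the convolution problem size, and the im2col matrix is never materialized, only indexed as if it existed. The helper below is illustrative only and not part of the CUTLASS API.

```cpp
// Illustrative sketch only: how a 2-D forward convolution (Fprop) maps onto a GEMM.
#include <cstdint>

struct Conv2dSize {
  int n, h, w, c;   // input activations, NHWC
  int k, r, s;      // filters, K x R x S x C
  int p, q;         // output spatial extents (NPQK)
};

struct GemmExtent { int64_t m, n, k; };

GemmExtent implicit_gemm_fprop_extent(Conv2dSize const &cs) {
  return {
    int64_t(cs.n) * cs.p * cs.q,  // GEMM M: one row per output pixel
    int64_t(cs.k),                // GEMM N: one column per filter
    int64_t(cs.r) * cs.s * cs.c   // GEMM K: reduction over the filter footprint
  };
}
```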
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

See the [functionality listing](media/docs/functionality.md) for the list of operations
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

# What's New in CUTLASS 2.4
CUTLASS 2.4 is a significant update to CUTLASS adding:
- 1-D, 2-D, and 3-D convolution targeting Tensor and CUDA cores for NVIDIA Ampere, Turing, and Volta GPU architectures (see the sketch after this list)
- CUTLASS profiler support for convolution
- [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
- See the [CHANGELOG](CHANGELOG.md) for more details.
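For quick orientation, the following sketch is condensed from the new `09_turing_tensorop_conv2dfprop` example added in this commit; the template arguments shown are the ones chosen in that example, not the only supported configuration.

```cpp
// Composing a forward-propagation (Fprop) implicit GEMM convolution kernel, condensed from
// examples/09_turing_tensorop_conv2dfprop in this commit.
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"

using ElementInput  = cutlass::int4b_t;                  // int4 activations and filters
using ElementOutput = cutlass::int4b_t;
using ElementAccum  = int32_t;
using Layout        = cutlass::layout::TensorNHWC;

using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementInput, Layout, ElementInput, Layout, ElementOutput, Layout, ElementAccum,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 128>,             // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 128>,               // warp tile
    cutlass::gemm::GemmShape<8, 8, 32>,                  // Tensor Core instruction
    cutlass::epilogue::thread::LinearCombinationClamp<ElementOutput, 8, ElementAccum, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                                   // pipeline stages
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kAnalytic>::Kernel;

// Device-level operator wrapping the kernel; launched through ImplicitGemm::Arguments.
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
```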
# What's New in CUTLASS 2.3

CUTLASS 2.3 is a minor update to CUTLASS adding:

@@ -118,6 +129,7 @@ CUTLASS is described in the following documents and the accompanying
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](media/docs/terminology.md) - describes terms used in the code
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++

@@ -140,7 +152,7 @@ CUTLASS unit tests, examples, and utilities can be build with CMake starting ver
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.

```
```bash
$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```

@@ -149,7 +161,7 @@ for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.

```
```bash
$ mkdir build && cd build

$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA's Ampere Architecture

@@ -160,7 +172,7 @@ From the `build/` directory, compile and run the CUTLASS unit tests by building
The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS,
and they may be executed in parallel via make's `-j` command line argument.

```
```bash
$ make test_unit -j
...
...

@@ -191,6 +203,8 @@ include/ # client applications should target this directory

  arch/        # direct exposure of architecture features (including instruction-level GEMMs)

  conv/        # code specialized for convolution

  gemm/        # code specialized for general matrix product computations

  layout/      # layout definitions for matrices, tensors, and other mathematical objects in memory

@@ -228,6 +242,8 @@ examples/

  08_turing_tensorop_gemm/         # example demonstrating integer GEMM using Turing Tensor Cores

  09_turing_tensorop_conv2dfprop/  # example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores

  10_planar_complex/               # example demonstrating planar complex GEMM kernels

  11_planar_complex_array/         # example demonstrating planar complex kernels with batch-specific problem sizes

@@ -235,9 +251,12 @@ examples/
  12_gemm_bias_relu/               # example demonstrating GEMM fused with bias and relu

  13_fused_two_gemms/              # example demonstrating two GEMMs fused in one kernel

  22_ampere_tensorop_conv2dfprop/  # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
```

### Tools

```
tools/
  library/   # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates

@@ -266,14 +285,14 @@ Instructions for building and running the Unit tests are described in the [Quick
The `tools/profiler/` directory contains a command-line utility for launching each of the GEMM kernels.
It can be built as follows:

```
```bash
$ make cutlass_profiler -j16
```

By default, only one tile size is instantiated for each data type, math instruction, and layout.
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
Beware, this results in *thousands* of kernels and long build times.
```
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j16

@@ -282,7 +301,7 @@ $ make cutlass_profiler -j16
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
wildcard characters may be used to reduce the set of kernels. The following builds exactly one kernel:

```
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1
...
$ make cutlass_profiler -j16

@@ -318,6 +337,56 @@ $ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
  Math: 17218.4 GFLOP/s
```

To compile strictly 2-D or 3-D convolution kernels, filter by operation:
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_OPERATIONS=conv2d,conv3d
...
$ make cutlass_profiler -j16
```

or by name:

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS=sfprop,s16816fprop,s16816dgrad,s16816wgrad
...
$ make cutlass_profiler -j16
```

An example command line for profiling 2-D convolution kernels is as follows:

```bash
$ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3


=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: conv2d
       Operation: cutlass_simt_sfprop_optimized_128x128_8x2_nhwc

          Status: Success
    Verification: ON
     Disposition: Passed

reference_device: Passed

       Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \
                  --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f32:nhwc --Filter=f32:nhwc --Output=f32:nhwc \
                  --conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \
                  --eq_gemm_provider=none --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
                  --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024

           Bytes: 2055798784 bytes
           FLOPs: 118482796544 flops

         Runtime: 8.13237 ms
          Memory: 235.431 GiB/s

            Math: 14569.3 GFLOP/s

```
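The profiler's derived metrics follow directly from the counters it reports. A quick arithmetic check of the sample output above (values copied verbatim from it) looks like this:

```cpp
// Sanity-checking the profiler's derived throughput from its own reported counters.
#include <cstdio>

int main() {
  double bytes   = 2055798784.0;    // Bytes moved, from the report above
  double flops   = 118482796544.0;  // FLOPs, from the report above
  double runtime = 8.13237e-3;      // Runtime in seconds

  std::printf("Memory: %.1f GiB/s\n", bytes / runtime / (1024.0 * 1024.0 * 1024.0)); // ~235.4
  std::printf("Math:   %.1f GFLOP/s\n", flops / runtime / 1.0e9);                    // ~14569
  return 0;
}
```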
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
cmake/CTestTestfile.config.cmake (new file, 19 lines)

@@ -0,0 +1,19 @@
# Generated file

if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
  set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
  set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()

if (NOT "@TEST_EXE_DIR@" STREQUAL "")
  set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@)
else()
  set(TEST_EXE_PATH @TEST_EXE@)
endif()

add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)

if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
  set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
endif()
cuBLAS.cmake (21 lines changed)

@@ -1,3 +1,24 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

message(STATUS "Configuring cublas ...")
cuDNN.cmake (new file, 107 lines)

@@ -0,0 +1,107 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

if(DEFINED CUDNN_ENABLED)
  set(CUTLASS_ENABLE_CUDNN ${CUDNN_ENABLED} CACHE BOOL "Enable CUTLASS to build with cuDNN library.")
endif()

if(DEFINED CUTLASS_ENABLE_CUDNN AND NOT CUTLASS_ENABLE_CUDNN)
  return()
endif()

message(STATUS "Configuring cuDNN ...")

find_path(
  _CUDNN_INCLUDE_DIR cudnn.h
  PATHS
    ${CUDA_TOOLKIT_ROOT_DIR}/include
    $ENV{CUDNN_PATH}/include
    $ENV{CUDA_PATH}/include
    ${CUDNN_PATH}/include
    /usr/include)

find_library(
  _CUDNN_LIBRARY cudnn
  HINTS
    ${CUDA_TOOLKIT_ROOT_DIR}/lib64
    ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
    ${CUDA_TOOLKIT_ROOT_DIR}/lib
    $ENV{CUDNN_PATH}/lib64
    $ENV{CUDNN_PATH}/lib/x64
    $ENV{CUDNN_PATH}/lib
    $ENV{CUDA_PATH}/lib64
    $ENV{CUDA_PATH}/lib/x64
    $ENV{CUDA_PATH}/lib
    ${CUDNN_PATH}/lib64
    ${CUDNN_PATH}/lib/x64
    ${CUDNN_PATH}/lib
    /usr/lib/x86_64-linux-gnu
    /usr/lib)

if(_CUDNN_INCLUDE_DIR AND _CUDNN_LIBRARY)

  message(STATUS "cuDNN: ${_CUDNN_LIBRARY}")
  message(STATUS "cuDNN: ${_CUDNN_INCLUDE_DIR}")

  set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")

else()

  message(STATUS "cuDNN not found.")
  set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Found")

endif()

set(CUTLASS_ENABLE_CUDNN ${CUDNN_FOUND} CACHE BOOL "Enable CUTLASS to build with cuDNN library.")

if (CUTLASS_ENABLE_CUDNN AND NOT TARGET cudnn)

  set(CUDNN_INCLUDE_DIR ${_CUDNN_INCLUDE_DIR})
  set(CUDNN_LIBRARY ${_CUDNN_LIBRARY})

  if(WIN32)
    add_library(cudnn STATIC IMPORTED GLOBAL)
  else()
    add_library(cudnn SHARED IMPORTED GLOBAL)
  endif()

  add_library(nvidia::cudnn ALIAS cudnn)

  set_property(
    TARGET cudnn
    PROPERTY IMPORTED_LOCATION
    ${CUDNN_LIBRARY})

  target_include_directories(
    cudnn
    INTERFACE
    $<INSTALL_INTERFACE:include>
    $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>)

endif()

if(CUTLASS_ENABLE_CUDNN AND NOT CUDNN_FOUND)
  message(FATAL_ERROR "CUTLASS_ENABLE_CUDNN enabled but cuDNN library could not be found.")
endif()

message(STATUS "Configuring cuDNN ... done.")
examples/03_visualize_layout/CMakeLists.txt

@@ -20,9 +20,15 @@
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

set(TEST_COMMAND_00 RowMajor --extent=16,16)
set(TEST_COMMAND_01 "ColumnMajorInterleaved<4>" --extent=32,8 --output-shape=16 --vectorize=4)

cutlass_example_add_executable(
  03_visualize_layout
  visualize_layout.cpp
  register_layout.cu
  TEST_COMMAND_OPTIONS
  TEST_COMMAND_00
  TEST_COMMAND_01
  )
@@ -32,6 +32,8 @@
#include <iomanip>
#include <memory>

#include <cutlass/cutlass.h>

#include "options.h"
#include "register_layout.h"

@@ -133,6 +135,8 @@ int main(int argc, char const *arg[]) {

  layout_it->second->print_csv(std::cout);

  cudaFree(0); // Ensure CUDA is available.

  return 0;
}
@@ -188,31 +188,6 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,

int run() {

  // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
  // in CUDA 10.2.
  //
  // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
    std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
    return -1;
  }

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (!((props.major * 10 + props.minor) >= 75)) {
    std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
              << std::endl;

    // Return 0 so tests are considered passing if run on unsupported platforms.
    return 0;
  }

  const int length_m = 5120;
  const int length_n = 4096;
  const int length_k = 4096;

@@ -337,18 +312,37 @@ int run() {
}

int main() {
  bool notSupported = false;

  // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
  // in CUDA 10.2.
  //
  // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
    std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
    notSupported = true;
  }

  // Returning zero so this test passes when built on older Toolkits.
  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (!((props.major * 10 + props.minor) >= 75)) {
    std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
              << std::endl;

    notSupported = true;
  }

  if (notSupported) {
    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }
  else {

    return run();
  }
}
examples/09_turing_tensorop_conv2dfprop/CMakeLists.txt (new file, 28 lines)

@@ -0,0 +1,28 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  09_turing_tensorop_conv2dfprop
  turing_tensorop_conv2dfprop.cu
  )
examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu (new file, 758 lines)

@@ -0,0 +1,758 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/**

  This example shows how to run convolution kernels using functions and data structures
  provided by CUTLASS, using Tensor Cores on an NVIDIA Turing GPU.

  Writing a single high-performance convolution kernel is hard but doable, whereas writing
  high-performance kernels at scale that work for multiple problem sizes with good abstractions is
  really hard. CUTLASS solves this problem by providing simplified abstractions to compose
  multiple sections of an implicit GEMM kernel. When used properly, the kernels can easily approach
  peak GPU performance.

  CUTLASS divides a kernel into hierarchical, composable sections. This means that at the thread, warp,
  and thread-block level, each computes its own tile size, with higher-level tile sizes being composed
  from lower-level ones. Multiple thread-tiles (the tile size each thread computes) form warp-tiles
  (the tile size each warp computes), and multiple warp-tiles can be used to compute a
  threadblock-tile (the tile size computed by a thread block).
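  As a purely illustrative sketch (this is not the actual CUTLASS implementation), the tile
  hierarchy corresponds to a loop nest of roughly this shape, where each level owns one of the
  tile sizes chosen later in this file:

      // illustrative pseudocode only; ThreadblockShape/WarpShape/InstructionShape are the
      // template parameters defined below
      for (each threadblock_tile of ThreadblockShape in the output) {   // one thread block
        for (each warp_tile of WarpShape inside it) {                   // one warp
          for (each instruction_tile of InstructionShape inside it) {
            tensor_core_mma(instruction_tile);                          // one mma instruction
          }
        }
      }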
  In this example, we split variable initialization into
  1. Setting up data properties: describes how tensors are laid out in the memory and how the kernel
     can view them (logical to physical mapping)
  2. Setting up computation properties: describes how the above tensors will be used to compute
     the output of the convolution.

  First, we set up the data types of the input tensor A, the weights tensor B, and the output tensor C,
  along with alpha and beta, as the equation for convolution is C = alpha * Conv(A, B) + beta * C. In CUTLASS,
  the kernels first compute Conv(A, B) and leave the rest of the computation to the end of the kernel, as
  alpha * X + beta * C is a simple element-wise operation on X (= Conv(A, B)) and C. We call this the
  epilogue of the kernel. Hence, we set up the data types for alpha and beta to be
  ElementComputeEpilogue = float. We want to use MMA instructions on Turing, and they support 4-bit
  signed integers. But int4b_t is not fully supported by the NVIDIA software stack, so CUTLASS introduces
  cutlass::int4b_t. We use cutlass::int4b_t as the data type for elements in input tensors A and B. We
  convey this to the CUTLASS kernel by initializing the template variables ElementAccumulator (int32_t),
  ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB (cutlass::int4b_t),
  and ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out
  linearly in memory, we also have to convey the layout of the tensors. We do that by initializing the
  template variables LayoutInputA, LayoutInputB, and LayoutOutput to the CUTLASS layout TensorNHWC. Next,
  we set up the rules to compute alpha * X + beta * C, which is called the epilogue of the kernel. We
  initialize the template variable EpilogueOp, which takes the data type of the output ElementOutput
  (int32_t), the number of elements per vector memory access (32), the data type of the accumulator
  (int32_t), and the data type of the computation of the linear combination (alpha * X + beta * C).
  Now that we have set up the properties of the data, we have to set up the properties of the computation.

  Second, we create template variables for the tile sizes of the thread block, warp, and mma-op:
  128x128x128, 64x64x128, and 8x8x32 (MxNxK), respectively. When these are passed to instantiate the
  CUTLASS Implicit GEMM kernel, it internally deduces the number of threads needed per thread block, the
  amount of shared memory, how to store data in a bank-conflict-free manner, and a ton of other variables
  required to compose, initialize, and launch a high-performance Implicit GEMM kernel. This is the beauty
  of CUTLASS: it relieves the developer from understanding and coding complicated hardware optimizations
  which can easily go wrong.

  CUTLASS also supports multiple MMA pipelines in a thread block. What are MMA pipelines? MMA pipelines
  constitute the whole process of loading input data from global memory to shared memory, loading data
  from shared memory to registers, doing matrix multiplication, and storing to global memory. The flow
  sequence below shows a typical MMA pipeline:

  tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers ->
  output to global memory

  The problem with a single pipeline is that each stage is synchronous, which means each stage has to wait
  until the previous one has finished executing. There are stages in the pipeline which do not have fixed
  latency, for example, the loads from global memory and shared memory. Therefore, we can add one more
  pipeline with a phase shift in the MMA kernel to hide the latency of the global and shared memory loads.
  Finally, the pipeline in a kernel looks like:

  (1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers ->
  (5) mma -> (6) registers -> (7) output to global memory

  (1) <null> -> (2) <null> -> (3) tensor in global memory -> (4) registers -> (5) tile in shared memory ->
  (6) registers -> (7) mma -> (8) registers -> (9) output to global memory

  This way, you can hide the latency of the second global memory load by doing computation on already
  loaded input data.
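  As a purely illustrative sketch (not the actual CUTLASS MmaPipelined/MmaMultistage code), a
  two-stage pipeline amounts to double buffering the shared-memory tiles; the helper names here
  are hypothetical:

      load_global_to_smem(buffer[0], tile(0));                      // prologue: fetch the first tile
      for (int k = 0; k < num_tiles; ++k) {
        if (k + 1 < num_tiles) {
          load_global_to_smem(buffer[(k + 1) % 2], tile(k + 1));    // prefetch the next tile
        }
        warp_mma_on_smem(buffer[k % 2]);     // compute on the tile already resident in shared memory
        __syncthreads();                     // flip buffers only once all warps finish both phases
      }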
  There are a few more template variables initialized, such as which threadblock tile of the output
  matrix is computed by which thread block launched on an SM, and the CUDA SM architecture of the GPU
  you want to run on.

  These are all put together to create a template variable which describes the CUTLASS Implicit GEMM
  kernel using the cutlass::conv::device::ImplicitGemm template.

  The next step is to initialize physical data, instantiate and initialize the CUTLASS kernel, and run it.
  We use CUTLASS utilities to initialize, fill, and compare tensors, as they are simple and do not get
  in the way of learning CUTLASS.

  Once all the tensors are initialized and filled with data, we create the arguments tuple to launch the
  CUTLASS kernel, which takes the problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64,
  R = 3, S = 3, C = 128), padding, strides, dilation, tensors, alpha, beta, and the important one,
  the split k-dimension factor. Along with that, we query CUTLASS for any scratch-space memory required
  by the kernel we instantiated. If there is, we create it and pass it along with the other arguments
  created to initialize the CUTLASS kernel; then the kernel is launched.

  In this example, we later launch a reference convolution kernel (from CUTLASS utilities) to check
  whether the output from the CUTLASS kernel is the same as that of the reference implicit GEMM kernel.
*/
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = int32_t; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::int4b_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::int4b_t; // Data type of elements in input tensor
|
||||
using ElementOutput = cutlass::int4b_t; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm75;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 2;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
8, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(false),
|
||||
measure_performance(true),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of int4b_t elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 32 elements.
|
||||
//
|
||||
int const kAlignment = 32;
|
||||
|
||||
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "09_turing_tensorop_conv2dfprop example\n\n"
|
||||
<< " This example uses Turing's Tensor Core operators on int4 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(7),
|
||||
ElementInputA(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor C for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_ref_c.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
// mode (kCrossCorrelation or kConvolution)
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices);
|
||||
|
||||
typename ImplicitGemm::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemm implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on host...\n";
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::host::Conv2dFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverterClamp<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_c.host_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_c.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_c.host_view(),
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "09_tensor_conv_workspace_conv2dfprop_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major > 7 || (props.major == 7 && props.minor >= 5))) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
|
||||
|
||||
struct Benchmark {
|
||||
int h, w, c, k, r, s;
|
||||
} layers[] = {
|
||||
{56, 56, 64, 256, 1, 1},
|
||||
{56, 56, 64, 64, 1, 1},
|
||||
{56, 56, 64, 64, 3, 3},
|
||||
{56, 56, 256, 64, 1, 1},
|
||||
{56, 56, 256, 512, 1, 1},
|
||||
{56, 56, 256, 128, 1, 1},
|
||||
{28, 28, 128, 128, 3, 3},
|
||||
{28, 28, 128, 512, 1, 1},
|
||||
{28, 28, 512, 128, 1, 1},
|
||||
{28, 28, 512, 1024, 1, 1},
|
||||
{28, 28, 512, 256, 1, 1},
|
||||
{14, 14, 256, 256, 3, 3},
|
||||
{14, 14, 256, 1024, 1, 1},
|
||||
{14, 14, 1024, 256, 1, 1},
|
||||
{14, 14, 1024, 2048, 1, 1},
|
||||
{14, 14, 1024, 512, 1, 1},
|
||||
{7, 7, 512, 512, 3, 3},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &layer : layers) {
|
||||
for (auto N : batch_sizes) {
|
||||
|
||||
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
result.print(std::cout, idx, options) << std::endl;
|
||||
}
|
||||
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@ -106,21 +106,6 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 5120;
|
||||
const int length_n = 4096;
|
||||
const int length_k = 4096;
|
||||
@ -265,17 +250,36 @@ int run() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with the CUDA 10.2 Toolkit (or later) to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -55,22 +55,6 @@ Performance:
|
||||
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
||||
run_nonfused_gemm_s8_sm80();
|
||||
run_fused_gemm_s8_sm80();
|
||||
@ -85,17 +69,38 @@ int run() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with the CUDA 10.2 Toolkit (or later) to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -335,7 +335,7 @@ struct B2bGemm {
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k());
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
|
||||
@ -113,31 +113,6 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
|
||||
int run() {
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
|
||||
// Return 0 so tests are considered passing if run on unsupported platforms.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 5120;
|
||||
const int length_n = 4096;
|
||||
const int length_k = 4096;
|
||||
@ -262,17 +237,36 @@ int run() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
// Returning zero so this test passes when built on older Toolkits.
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,7 +71,7 @@ using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<256, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256
|
||||
cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
|
||||
// This code section describes the size of MMA op
|
||||
@ -123,31 +123,6 @@ constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
|
||||
int run() {
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
|
||||
// Return 0 so tests are considered passing if run on unsupported platforms.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 512;
|
||||
const int length_n = 512;
|
||||
const int length_k = 1024;
|
||||
@ -295,17 +270,37 @@ int run() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
// Returning zero so this test passes when built on older Toolkits.
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
examples/22_ampere_tensorop_conv2dfprop/CMakeLists.txt (new file, 28 lines)
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
22_ampere_tensorop_conv2dfprop
|
||||
ampere_tensorop_conv2dfprop.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,763 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**

This example shows how to run convolution kernels using functions and data structures
provided by CUTLASS, using Tensor Cores on an NVIDIA Ampere GPU.

Writing a single high-performance convolution kernel is hard but doable. Writing high-performance
kernels at scale, which work for multiple problem sizes with good abstractions, is much harder.
CUTLASS solves this problem by providing simplified abstractions that compose the sections of an
implicit GEMM kernel. When used properly, the kernels can easily reach peak GPU performance.

CUTLASS divides a kernel into hierarchical, composable sections: at the thread, warp, and
thread-block level, each computes its own tile size, with higher-level tile sizes composed from
lower-level ones. Multiple thread-tiles (the tile size each thread computes) form a warp-tile
(the tile size each warp computes), and multiple warp-tiles form a threadblock-tile (the tile
size computed by a threadblock).

In this example, we split variable initialization into
1. Setting up data properties: describes how tensors are laid out in memory and how the kernel
   can view them (logical-to-physical mapping).
2. Setting up computation properties: describes how the above tensors are used to compute the
   output of the convolution.

First, we set up the data types of the input tensor A, the weight tensor B, and the output tensor C,
along with alpha and beta, since the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C.
In CUTLASS, the kernels first compute Conv2dFprop(A, B) and defer the rest of the computation to the
end of the kernel, because alpha * X + beta * C is a simple element-wise operation on X (= Conv2dFprop(A, B))
and C. We call this the epilogue of the kernel. Hence, we set the data types of alpha and beta to
ElementComputeEpilogue = float. We use cutlass::half_t for the elements of input tensors A and B.
We convey this to the CUTLASS kernel by initializing the template variables ElementAccumulator (float),
ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t), and
ElementOutput (float). Communicating just the data types is not enough. Because the data is laid out
linearly in memory, we also have to convey the layout of each tensor, which we do by initializing the
template variables LayoutInputA, LayoutInputB, and LayoutOutput to cutlass::layout::TensorNHWC. Next,
we set up the rules to compute alpha * X + beta * C, the epilogue of the kernel, by initializing the
template variable EpilogueOp, which takes the data type of the output ElementOutput (float), the number
of elements per vectorized memory access (128 bits, i.e. four floats here), the data type of the
accumulator (float), and the data type used to compute the linear combination (alpha * X + beta * C).

Now that we have set up the properties of the data, we set up the properties of the computation.

Second, we set the template variables for the thread-block, warp, and mma-op tile sizes to 128x128x64,
64x64x64, and 16x8x16 (MxNxK), respectively. When these are passed to instantiate the CUTLASS Implicit
GEMM kernel, it internally deduces the number of threads needed per thread-block, the amount of shared
memory, how to store data in a bank-conflict-free manner, and a ton of other variables required to
compose, initialize, and launch a high-performance Implicit GEMM kernel. This is the beauty of CUTLASS:
it relieves the developer from understanding and coding complicated hardware optimizations, which can
easily go wrong.

CUTLASS also supports multiple MMA pipelines in a threadblock. An MMA pipeline constitutes the whole
process of loading input data from global memory to shared memory, loading data from shared memory
into registers, doing matrix multiplication, and storing results to global memory. The flow below shows
a typical mma multistage pipeline
(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h):

tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers
--mma--> registers --global stores--> output to global memory

NVIDIA Ampere uses `cp_async` to build a multistage software pipeline that better hides latency.
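As a concrete illustration (a sketch of the idea, not a full description of the implementation):
with the NumStages = 3 chosen below, while the mma instructions consume the tile of stage i that
already resides in shared memory, the cp_async copies filling stages i+1 and i+2 are still in flight,
and a cp_async_wait fence is issued before a stage is consumed. This overlap is what hides global
memory latency behind the math.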

There are a few more template variables to initialize, such as the threadblock swizzle that decides
which threadblock computes which tile of the output, and the CUDA SM architecture of the GPU you want
to run on.

These are all put together to create a template variable which describes the CUTLASS Implicit GEMM
kernel, using the cutlass::conv::device::ImplicitGemmConvolution template.

The next step is to initialize physical data, then instantiate, initialize, and run the CUTLASS kernel.
We use CUTLASS utilities to initialize, fill, and compare tensors, as they are simple and do not get
in the way of learning CUTLASS.

Once all the tensors are initialized and filled with data, we create the arguments tuple used to launch
the CUTLASS kernel. It takes the problem size (N, H, W, C), the filter size (K, R, S, C), padding,
strides, dilation, the tensors, alpha, beta and, importantly, the split k-dimension factor. Along with
that, we query CUTLASS for any scratch-space memory required by the kernel we instantiated. If any is
needed, we allocate it and pass it along with the other arguments to initialize the CUTLASS kernel;
then the kernel is launched.

In this example, we later launch a reference convolution kernel (from the CUTLASS utilities) to check
that the output from the CUTLASS kernel matches the reference implicit GEMM implementation.
*/
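To make the tile hierarchy described in the comment above concrete, here is a small standalone sketch.
The shape values (128x128x64 threadblock, 64x64x64 warp, 16x8x16 instruction) are the ones this example
defines below; the counts are obtained by simple division and are an illustration, not CUTLASS internals.

// Sketch of how the tile shapes chosen in this example nest.
#include <cstdio>

int main() {
  int warps_m = 128 / 64, warps_n = 128 / 64, warps_k = 64 / 64;   // 2 x 2 x 1 warps
  int warps_per_block = warps_m * warps_n * warps_k;               // 4 warps
  int threads_per_block = warps_per_block * 32;                    // 128 threads
  int mma_m = 64 / 16, mma_n = 64 / 8, mma_k = 64 / 16;            // mma ops tiling one warp tile
  std::printf("%d warps (%d threads) per threadblock, %d x %d x %d mma ops per warp tile\n",
              warps_per_block, threads_per_block, mma_m, mma_n, mma_k);
  return 0;
}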
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementOutput = float; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipeline stages to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
// This code section describes which iterator algorithm is selected: Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kAnalytic;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,     // The number of elements per vectorized
                                                          // memory access. This becomes the vector width of
                                                          // math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(false),
|
||||
measure_performance(true),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
||||
//
|
||||
int const kAlignment = 8;
|
||||
|
||||
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "22_ampere_tensorop_conv2dfprop example\n\n"
|
||||
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n <int> Input tensor extent N\n"
|
||||
<< " --h <int> Input tensor extent H\n"
|
||||
<< " --w <int> Input tensor extent W\n"
|
||||
<< " --c <int> Input tensor extent C\n"
|
||||
<< " --k <int> Filter extent K\n"
|
||||
<< " --r <int> Filter extent R\n"
|
||||
<< " --s <int> Filter extent S\n\n"
|
||||
<< " --alpha <float> Epilogue scalar alpha\n"
|
||||
<< " --beta <float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag <string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/22_ampere_tensorop_conv2dfprop/22_ampere_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/22_ampere_tensorop_conv2dfprop/22_ampere_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
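  /// With unit dilation (as this helper assumes and as this example uses),
  /// P = (H + padding.n() + padding.h() - R) / stride_h + 1 and
  /// Q = (W + padding.w() + padding.c() - S) / stride_w + 1; the padding coordinate holds the
  /// two H-direction pads in .n()/.h() and the two W-direction pads in .w()/.c().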
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
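As a quick sanity check of the FLOP arithmetic above, the following standalone sketch evaluates it for
the default problem size of this example (input 1x32x32x32, filter 32x3x3x32, padding 1, stride 1, so
the output is 1x32x32x32); the numbers are an illustration only.

#include <cstdint>
#include <cstdio>

int main() {
  int64_t npqk = int64_t(1) * 32 * 32 * 32;   // output elements N * P * Q * K
  int64_t crs  = int64_t(3) * 3 * 32;         // accumulation depth per output element, R * S * C
  int64_t fmas = npqk * crs;                  // 9,437,184 multiply-adds
  double gflop = 2.0 * double(fmas) / 1.0e9;  // ~0.0189 GFLOP per convolution
  std::printf("%lld FMAs, %.4f GFLOP\n", (long long)fmas, gflop);
  return 0;
}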
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(7),
|
||||
ElementInputA(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor C for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_ref_c.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split the K dimension into a single partition (i.e., no split-K)
|
||||
int split_k_slices = 1;
|
||||
|
||||
typename ImplicitGemm::Arguments arguments{
|
||||
{
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices
|
||||
},
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemm implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on host...\n";
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
mode
|
||||
);
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::host::Conv2dFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_c.host_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_c.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_c.host_view(),
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "22_ampere_workspace_conv2dfprop_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with the CUDA 11.0 Toolkit (or later) to run these Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
|
||||
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
|
||||
|
||||
struct Benchmark {
|
||||
int h, w, c, k, r, s;
|
||||
} layers[] = {
|
||||
{56, 56, 64, 256, 1, 1},
|
||||
{56, 56, 64, 64, 1, 1},
|
||||
{56, 56, 64, 64, 3, 3},
|
||||
{56, 56, 256, 64, 1, 1},
|
||||
{56, 56, 256, 512, 1, 1},
|
||||
{56, 56, 256, 128, 1, 1},
|
||||
{28, 28, 128, 128, 3, 3},
|
||||
{28, 28, 128, 512, 1, 1},
|
||||
{28, 28, 512, 128, 1, 1},
|
||||
{28, 28, 512, 1024, 1, 1},
|
||||
{28, 28, 512, 256, 1, 1},
|
||||
{14, 14, 256, 256, 3, 3},
|
||||
{14, 14, 256, 1024, 1, 1},
|
||||
{14, 14, 1024, 256, 1, 1},
|
||||
{14, 14, 1024, 2048, 1, 1},
|
||||
{14, 14, 1024, 512, 1, 1},
|
||||
{7, 7, 512, 512, 3, 3},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &layer : layers) {
|
||||
for (auto N : batch_sizes) {
|
||||
|
||||
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
result.print(std::cout, idx, options) << std::endl;
|
||||
}
|
||||
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
@ -22,15 +22,20 @@
|
||||
|
||||
set(CUTLASS_EXAMPLES_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common)
|
||||
|
||||
add_custom_target(cutlass_examples)
|
||||
add_custom_target(test_examples)
|
||||
|
||||
function(cutlass_example_add_executable NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs)
|
||||
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
add_dependencies(cutlass_examples ${NAME})
|
||||
|
||||
target_link_libraries(
|
||||
${NAME}
|
||||
PRIVATE
|
||||
@ -44,19 +49,21 @@ function(cutlass_example_add_executable NAME)
|
||||
${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR}
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
test_${NAME}
|
||||
COMMAND
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${NAME}>
|
||||
DEPENDS
|
||||
${NAME}
|
||||
install(
|
||||
TARGETS ${NAME}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
)
|
||||
|
||||
cutlass_add_executable_tests(
|
||||
test_examples_${NAME} ${NAME}
|
||||
DEPENDS ${__DEPENDS}
|
||||
DEPENDEES test_examples ${__DEPENDEES}
|
||||
TEST_COMMAND_OPTIONS ${__TEST_COMMAND_OPTIONS}
|
||||
DISABLE_EXECUTABLE_INSTALL_RULE
|
||||
)
|
||||
|
||||
endfunction()
|
||||
|
||||
add_custom_target(cutlass_examples)
|
||||
add_custom_target(test_examples)
|
||||
|
||||
foreach(EXAMPLE
|
||||
00_basic_gemm
|
||||
01_cutlass_utilities
|
||||
@ -67,16 +74,16 @@ foreach(EXAMPLE
|
||||
06_splitK_gemm
|
||||
07_volta_tensorop_gemm
|
||||
08_turing_tensorop_gemm
|
||||
09_turing_tensorop_conv2dfprop
|
||||
10_planar_complex
|
||||
11_planar_complex_array
|
||||
12_gemm_bias_relu
|
||||
13_fused_two_gemms
|
||||
14_ampere_tf32_tensorop_gemm
|
||||
15_ampere_sparse_tensorop_gemm
|
||||
22_ampere_tensorop_conv2dfprop
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
add_dependencies(cutlass_examples ${EXAMPLE})
|
||||
add_dependencies(test_examples test_${EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
@ -74,6 +74,10 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async<SizeInBytes, CacheOperation::Always> {
|
||||
// Make sure the size is supported.
|
||||
static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16),
|
||||
"Size is not supported");
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
@ -104,6 +108,10 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async_zfill<SizeInBytes, CacheOperation::Always> {
|
||||
// Make sure the size is supported.
|
||||
static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16),
|
||||
"Size is not supported");
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
@ -138,6 +146,10 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async<SizeInBytes, CacheOperation::Global> {
|
||||
// Make sure the size is supported.
|
||||
static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16),
|
||||
"Size is not supported");
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
@ -171,6 +183,10 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
|
||||
// Make sure the size is supported.
|
||||
static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16),
|
||||
"Size is not supported");
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
@ -235,4 +251,3 @@ CUTLASS_DEVICE void cp_async_wait<0>() {
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -201,5 +201,5 @@ struct SparseMma;
|
||||
#include "cutlass/arch/mma_sm70.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
#include "cutlass/arch/sp_mma_sm80.h"
|
||||
#include "cutlass/arch/mma_sparse_sm80.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -365,7 +365,7 @@ struct Mma<
|
||||
}
|
||||
};
|
||||
|
||||
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
|
||||
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
|
||||
template <>
|
||||
struct Mma<
|
||||
gemm::GemmShape<8, 8, 16>,
|
||||
@ -599,7 +599,7 @@ struct Mma<
|
||||
}
|
||||
};
|
||||
|
||||
/// Matrix multiply-add operation: S32 = S8 * U8 + S32
|
||||
/// Matrix multiply-add operation: S32 = U8 * U8 + S32
|
||||
template <>
|
||||
struct Mma<
|
||||
gemm::GemmShape<8,8,16>,
|
||||
|
||||
@ -29,7 +29,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mma_sm80.h"
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "mma.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -52,7 +52,7 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif //__clang__
|
||||
#endif //!defined(__clang__)
|
||||
|
||||
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||
|
||||
@ -82,6 +82,12 @@ struct CutlassToWmmaDataType<cutlass::half_t> {
|
||||
using Type = __half;
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
template<>
|
||||
struct CutlassToWmmaDataType<cutlass::bfloat16_t> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
/// Statically maps int8_t => char
|
||||
template<>
|
||||
@ -158,6 +164,14 @@ template<>
|
||||
struct WmmaToCutlassDataType<__half> {
|
||||
using Type = cutlass::half_t;
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
template<>
|
||||
struct WmmaToCutlassDataType<__nv_bfloat16> {
|
||||
using Type = cutlass::bfloat16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
include/cutlass/conv/conv2d_problem_size.h (new file, 450 lines)
@ -0,0 +1,450 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief This file contains definitions and utility functions for describing convolution problem sizes.
|
||||
|
||||
Conv2dProblem description:
|
||||
activation (NHWC),
|
||||
filter (KRSC),
|
||||
output (NPQK),
|
||||
padding (pad_h, pad_w),
|
||||
stride (stride_h, stride_w),
|
||||
dilation (dilation_h, dilation_w).
|
||||
|
||||
Free functions to map:
|
||||
Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
|
||||
Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
|
||||
Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Problem size structure
|
||||
struct Conv2dProblemSize {
|
||||
|
||||
// Parameters that strictly define the Conv2d problem size
|
||||
int N, H, W, C, P, Q, K, R, S;
|
||||
int pad_h, pad_w;
|
||||
int stride_h, stride_w;
|
||||
int dilation_h, dilation_w;
|
||||
Mode mode;
|
||||
|
||||
// Conv2d implementation-related parameters
|
||||
int split_k_slices;
|
||||
int groups;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize():
|
||||
N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0),
|
||||
pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1),
|
||||
mode(Mode::kConvolution), split_k_slices(1), groups(1) { }
|
||||
|
||||
/// Constructor for default padding, stride, dilation, and split-K
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize(
|
||||
int N,
|
||||
int H,
|
||||
int W,
|
||||
int C,
|
||||
int P,
|
||||
int Q,
|
||||
int K,
|
||||
int R,
|
||||
int S,
|
||||
Mode mode
|
||||
):
|
||||
N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S),
|
||||
pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1),
|
||||
mode(mode), split_k_slices(1), groups (1) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize(
|
||||
int N,
|
||||
int H,
|
||||
int W,
|
||||
int C,
|
||||
int K,
|
||||
int R,
|
||||
int S,
|
||||
int P,
|
||||
int Q,
|
||||
int pad_h,
|
||||
int pad_w,
|
||||
int stride_h,
|
||||
int stride_w,
|
||||
int dilation_h,
|
||||
int dilation_w,
|
||||
Mode mode,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
N(N), H(H), W(W), C(C), K(K), R(R), S(S), P(P), Q(Q),
|
||||
pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w),
|
||||
dilation_h(dilation_h), dilation_w(dilation_w),
|
||||
mode(mode), split_k_slices(split_k_slices), groups (groups) { }
|
||||
|
||||
/// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
|
||||
// sets P and Q from the user-defined output size (all data members are ctor arguments)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize(
|
||||
cutlass::Tensor4DCoord input_size, // NHWC
|
||||
cutlass::Tensor4DCoord filter_size, // KRSC
|
||||
cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _
|
||||
cutlass::MatrixCoord stride, // stride_h, stride_w
|
||||
cutlass::MatrixCoord dilation, // dilation_h, dilation_w
|
||||
cutlass::Tensor4DCoord output_size, // NPQK
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
|
||||
K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
|
||||
pad_h(padding[0]), pad_w(padding[2]),
|
||||
stride_h(stride.row()), stride_w(stride.column()),
|
||||
dilation_h(dilation.row()), dilation_w(dilation.column()),
|
||||
P(output_size.h()), Q(output_size.w()),
|
||||
mode(mode), split_k_slices(split_k_slices), groups(groups) {}
|
||||
|
||||
/// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
|
||||
// computes output size and sets P and Q (skip output from ctor arguments)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize(
|
||||
cutlass::Tensor4DCoord input_size, // NHWC
|
||||
cutlass::Tensor4DCoord filter_size, // KRSC
|
||||
cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _
|
||||
cutlass::MatrixCoord stride, // stride_h, stride_w
|
||||
cutlass::MatrixCoord dilation, // dilation_h, dilation_w
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
|
||||
K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
|
||||
pad_h(padding[0]), pad_w(padding[2]),
|
||||
stride_h(stride.row()), stride_w(stride.column()),
|
||||
dilation_h(dilation.row()), dilation_w(dilation.column()),
|
||||
mode(mode), split_k_slices(split_k_slices), groups(groups) {
|
||||
// set output P and Q
|
||||
P = ((H + pad_h * 2 - R * dilation_h) / stride_h) + 1;
|
||||
Q = ((W + pad_w * 2 - S * dilation_w) / stride_w) + 1;
|
||||
}
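// Worked example of the output-size computation above (illustrative values): with
// H = W = 224, R = S = 3, pad_h = pad_w = 1, unit stride, and unit dilation,
//   P = ((224 + 2*1 - 3*1) / 1) + 1 = 224
//   Q = ((224 + 2*1 - 3*1) / 1) + 1 = 224
// so a "same"-padded 3x3 convolution preserves the spatial extent.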
|
||||
|
||||
/// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord
|
||||
// sets P and Q from the user-defined output size (padding, stride, and dilation take default values)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize(
|
||||
cutlass::Tensor4DCoord input_size, // NHWC
|
||||
cutlass::Tensor4DCoord filter_size, // KRSC
|
||||
cutlass::Tensor4DCoord output_size, // NPQK
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()),
|
||||
K(filter_size.n()), R(filter_size.h()), S(filter_size.w()),
|
||||
P(output_size.h()), Q(output_size.w()),
|
||||
pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1),
|
||||
dilation_h(1), dilation_w(1),
|
||||
mode(mode), split_k_slices(split_k_slices), groups(groups) {}
|
||||
|
||||
// Reset convolution mode in the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) {
|
||||
Conv2dProblemSize tmp(*this);
|
||||
tmp.mode = mode_;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// Reset split-K slice count in the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dProblemSize reset_split_k_slices(int split_k_slices_) {
|
||||
Conv2dProblemSize tmp(*this);
|
||||
tmp.split_k_slices = split_k_slices_;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
/// Equality operator (ignores mode, split_k_slices, and groups)
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Conv2dProblemSize const &conv) const {
|
||||
return (
|
||||
(N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
|
||||
(K == conv.K) && (R == conv.R) && (S == conv.S) &&
|
||||
(P == conv.P) && (Q == conv.Q) &&
|
||||
(pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
|
||||
(stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
|
||||
(dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
|
||||
);
|
||||
}
|
||||
|
||||
/// Inequality operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Conv2dProblemSize const &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
/// Returns activation extent as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord activation_extent() const {
|
||||
|
||||
return cutlass::Tensor4DCoord ({N, H, W, C});
|
||||
}
|
||||
|
||||
/// Returns filter extent as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord filter_extent() const {
|
||||
|
||||
return cutlass::Tensor4DCoord ({K, R, S, C});
|
||||
}
|
||||
|
||||
/// Returns output extent as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord output_extent() const {
|
||||
|
||||
return cutlass::Tensor4DCoord ({N, P, Q, K});
|
||||
}
|
||||
|
||||
/// Returns activation size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t activation_size() const {
|
||||
|
||||
return (N * H * W * C);
|
||||
}
|
||||
|
||||
/// Returns filter size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t filter_size() const {
|
||||
|
||||
return (K * R * S * C);
|
||||
}
|
||||
|
||||
/// Returns output size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t output_size() const {
|
||||
|
||||
return (N * P * Q * K);
|
||||
}
|
||||
|
||||
/// Returns padding as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord padding() const {
|
||||
|
||||
return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w});
|
||||
}
|
||||
|
||||
/// Returns stride as MatrixCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::MatrixCoord stride() const {
|
||||
|
||||
return cutlass::MatrixCoord ({stride_h, stride_w});
|
||||
}
|
||||
|
||||
/// Returns dilation as MatrixCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::MatrixCoord dilation() const {
|
||||
|
||||
return cutlass::MatrixCoord ({dilation_h, dilation_w});
|
||||
}
|
||||
};
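// Illustrative construction sketch (the shapes below are arbitrary example values). The
// Tensor4DCoord/MatrixCoord constructor that omits the output size derives P and Q:
//
//   cutlass::conv::Conv2dProblemSize problem(
//     cutlass::Tensor4DCoord(1, 56, 56, 64),    // input NHWC
//     cutlass::Tensor4DCoord(128, 3, 3, 64),    // filter KRSC
//     cutlass::Tensor4DCoord(1, 0, 1, 0),       // padding (pad_h, _, pad_w, _)
//     cutlass::MatrixCoord(2, 2),               // stride
//     cutlass::MatrixCoord(1, 1));              // dilation
//
//   // problem.P == 28 and problem.Q == 28 for this 3x3, pad-1, stride-2 problem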
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ImplicitGemm helper functions //
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Determine the problem size of the implicit GEMM operation
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord implicit_gemm_problem_size(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
// Compute problem size
|
||||
switch (conv_operator) {
|
||||
case Operator::kFprop:
|
||||
return gemm::GemmCoord(
|
||||
problem_size.N * problem_size.P * problem_size.Q,
|
||||
problem_size.K,
|
||||
problem_size.R * problem_size.S * problem_size.C
|
||||
);
|
||||
case Operator::kDgrad:
|
||||
return gemm::GemmCoord(
|
||||
problem_size.N * problem_size.H * problem_size.W,
|
||||
problem_size.C,
|
||||
problem_size.R * problem_size.S * problem_size.K
|
||||
);
|
||||
case Operator::kWgrad:
|
||||
return gemm::GemmCoord(
|
||||
problem_size.K,
|
||||
problem_size.R * problem_size.S * problem_size.C,
|
||||
problem_size.N * problem_size.P * problem_size.Q
|
||||
);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return gemm::GemmCoord();
|
||||
}
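// Example of the Fprop mapping above (illustrative values): for N = 1, P = Q = 56,
// K = 128, R = S = 3, C = 64 the resulting GEMM shape is
//   gemm_m = N * P * Q = 1 * 56 * 56 = 3136
//   gemm_n = K         = 128
//   gemm_k = R * S * C = 3 * 3 * 64  = 576
// i.e. every output pixel becomes a GEMM row and every (r, s, c) filter position a GEMM
// reduction element; Dgrad and Wgrad permute the roles as coded above.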
|
||||
|
||||
// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm
|
||||
CUTLASS_HOST_DEVICE
|
||||
int implicit_gemm_k_iterations(
|
||||
Operator conv_operator,
|
||||
int threadblock_K,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
|
||||
int iterations = 0;
|
||||
int elements_per_split_k_slice = 0;
|
||||
|
||||
switch (conv_operator) {
|
||||
case Operator::kFprop:
|
||||
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
||||
iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
|
||||
break;
|
||||
|
||||
case Operator::kDgrad:
|
||||
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
||||
iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
|
||||
break;
|
||||
|
||||
case Operator::kWgrad:
|
||||
elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
||||
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return iterations;
|
||||
}
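// Worked example of the iteration count above (illustrative values): Fprop with C = 64,
// R = S = 3, threadblock_K = 32, and split_k_slices = 1 gives
//   elements_per_split_k_slice = (64 + 1 - 1) / 1 = 64
//   iterations = 3 * 3 * ((64 + 32 - 1) / 32) = 9 * 2 = 18
// so the mainloop visits two C-chunks for each of the nine (r, s) filter taps.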
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Returns ImplicitGemm tensor A extent as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.activation_extent();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.output_extent();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.output_extent();
|
||||
default : break;
|
||||
}
|
||||
return cutlass::Tensor4DCoord();
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor B extent as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.filter_extent();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent();
|
||||
default : break;
|
||||
}
|
||||
return cutlass::Tensor4DCoord();
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor C extent as Tensor4DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.output_extent();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent();
|
||||
default : break;
|
||||
}
|
||||
return cutlass::Tensor4DCoord();
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor A size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t implicit_gemm_tensor_a_size(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.activation_size();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.output_size();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.output_size();
|
||||
default : break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor B size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t implicit_gemm_tensor_b_size(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.filter_size();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.filter_size();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.activation_size();
|
||||
default : break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor C size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t implicit_gemm_tensor_c_size(
|
||||
Operator conv_operator,
|
||||
Conv2dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.output_size();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.activation_size();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.filter_size();
|
||||
default : break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
453
include/cutlass/conv/conv3d_problem_size.h
Normal file
@ -0,0 +1,453 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief This file contains definitions and utility functions for describing convolution problem sizes.
|
||||
|
||||
Conv3dProblem description:
|
||||
activation (NDHWC),
|
||||
filter (KTRSC),
|
||||
output (NZPQK),
|
||||
padding (pad_d, pad_h, pad_w),
|
||||
stride (stride_d, stride_h, stride_w),
|
||||
dilation (dilation_d, dilation_h, dilation_w).
|
||||
|
||||
Free functions to map:
|
||||
Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
|
||||
Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
|
||||
Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Problem size structure
|
||||
struct Conv3dProblemSize : public Conv2dProblemSize {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
// 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions
|
||||
using Coord3D = Coord<3>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
// Parameters that strictly define the Conv3d problem size
|
||||
int D, T, Z; // input depth, filter depth, output depth
|
||||
int pad_d; // padding in depth dimension
|
||||
int stride_d; // stride in depth dimension
|
||||
int dilation_d; // dilation in depth dimension
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
public:
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize():
|
||||
D(0), T(0), Z(0),
|
||||
pad_d(0),
|
||||
stride_d(1),
|
||||
dilation_d(1),
|
||||
Conv2dProblemSize() { }
|
||||
|
||||
/// Constructor for default padding, stride, dilation, and split-K
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize(
|
||||
int N,
|
||||
int D,
|
||||
int H,
|
||||
int W,
|
||||
int C,
|
||||
int Z,
|
||||
int P,
|
||||
int Q,
|
||||
int K,
|
||||
int T,
|
||||
int R,
|
||||
int S,
|
||||
Mode mode
|
||||
):
|
||||
D(D), T(T), Z(Z),
|
||||
pad_d(T / 2), stride_d(1), dilation_d(1),
|
||||
Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize(
|
||||
int N,
|
||||
int D,
|
||||
int H,
|
||||
int W,
|
||||
int C,
|
||||
int K,
|
||||
int T,
|
||||
int R,
|
||||
int S,
|
||||
int Z,
|
||||
int P,
|
||||
int Q,
|
||||
int pad_d,
|
||||
int pad_h,
|
||||
int pad_w,
|
||||
int stride_d,
|
||||
int stride_h,
|
||||
int stride_w,
|
||||
int dilation_d,
|
||||
int dilation_h,
|
||||
int dilation_w,
|
||||
Mode mode,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
D(D), T(T), Z(Z),
|
||||
pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d),
|
||||
Conv2dProblemSize(
|
||||
N, H, W, C, K, R, S, P, Q,
|
||||
pad_h, pad_w,
|
||||
stride_h, stride_w,
|
||||
dilation_h, dilation_w,
|
||||
mode, split_k_slices, groups) { }
|
||||
|
||||
/// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D
|
||||
// sets Z, P, and Q from the *user-defined* output size (all data members are ctor arguments)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize(
|
||||
cutlass::Tensor5DCoord input_size, // NDHWC
|
||||
cutlass::Tensor5DCoord filter_size, // KTRSC
|
||||
Coord3D padding, // pad_d, pad_h, pad_w
|
||||
Coord3D stride, // stride_d, stride_h, stride_w
|
||||
Coord3D dilation, // dilation_d, dilation_h, dilation_w
|
||||
cutlass::Tensor5DCoord output_size, // NZPQK
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
D(input_size.d()), T(filter_size.d()), Z(output_size.d()),
|
||||
pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]),
|
||||
Conv2dProblemSize(
|
||||
{input_size.n(), input_size.h(), input_size.w(), input_size.c()},
|
||||
{filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
|
||||
{padding[1], padding[1], padding[2], padding[2]},
|
||||
{stride[1], stride[2]},
|
||||
{dilation[1], dilation[2]},
|
||||
{output_size.n(), output_size.h(), output_size.w(), output_size.c()},
|
||||
mode, split_k_slices, groups
|
||||
) { }
|
||||
|
||||
/// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D
|
||||
// *computes* the output size and sets Z, P, and Q (output size is omitted from the ctor arguments)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize(
|
||||
cutlass::Tensor5DCoord input_size, // NDHWC
|
||||
cutlass::Tensor5DCoord filter_size, // KTRSC
|
||||
Coord3D padding, // pad_d, pad_h, pad_w
|
||||
Coord3D stride, // stride_d, stride_h, stride_w
|
||||
Coord3D dilation, // dilation_d, dilation_h, dilation_w
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
|
||||
int split_k_slices = 1,
|
||||
int groups = 1
|
||||
):
|
||||
D(input_size.d()), T(filter_size.d()),
|
||||
pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]),
|
||||
Conv2dProblemSize(
|
||||
{input_size.n(), input_size.h(), input_size.w(), input_size.c()},
|
||||
{filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
|
||||
{padding[1], padding[1], padding[2], padding[2]},
|
||||
{stride[1], stride[2]},
|
||||
{dilation[1], dilation[2]},
|
||||
mode, split_k_slices, groups
|
||||
) {
|
||||
// set output Z
|
||||
Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1;
|
||||
}
|
||||
|
||||
/// Equality operator (ignores mode, split_k_slices, and groups)
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Conv3dProblemSize const &conv) const {
|
||||
return (
|
||||
(N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
|
||||
(K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) &&
|
||||
(Z == conv.Z) && (P == conv.P) && (Q == conv.Q) &&
|
||||
(pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
|
||||
(stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
|
||||
(dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
|
||||
);
|
||||
}
|
||||
|
||||
/// Inequality operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Conv3dProblemSize const &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
// Reset convolution mode in the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) {
|
||||
Conv3dProblemSize tmp(*this);
|
||||
tmp.mode = mode_;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// Reset split-K slice count in the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dProblemSize reset_split_k_slices(int split_k_slices_) {
|
||||
Conv3dProblemSize tmp(*this);
|
||||
tmp.split_k_slices = split_k_slices_;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
/// Returns activation extent as Tensor5DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor5DCoord activation_extent() const {
|
||||
|
||||
return cutlass::Tensor5DCoord ({N, D, H, W, C});
|
||||
}
|
||||
|
||||
/// Returns filter extent as Tensor5DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor5DCoord filter_extent() const {
|
||||
|
||||
return cutlass::Tensor5DCoord ({K, T, R, S, C});
|
||||
}
|
||||
|
||||
/// Returns output extent as Tensor5DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor5DCoord output_extent() const {
|
||||
|
||||
return cutlass::Tensor5DCoord ({N, Z, P, Q, K});
|
||||
}
|
||||
|
||||
/// Returns activation size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t activation_size() const {
|
||||
|
||||
return (N * D * H * W * C);
|
||||
}
|
||||
|
||||
/// Returns filter size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t filter_size() const {
|
||||
|
||||
return (K * T * R * S * C);
|
||||
}
|
||||
|
||||
/// Returns output size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t output_size() const {
|
||||
|
||||
return (N * Z * P * Q * K);
|
||||
}
|
||||
|
||||
/// Returns padding as Coord3D
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord3D padding() const {
|
||||
|
||||
return Coord3D ({pad_d, pad_h, pad_w});
|
||||
}
|
||||
|
||||
/// Returns stride as Coord3D
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord3D stride() const {
|
||||
|
||||
return Coord3D ({stride_d, stride_h, stride_w});
|
||||
}
|
||||
|
||||
/// Returns dilation as Coord3D
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord3D dilation() const {
|
||||
|
||||
return Coord3D ({dilation_d, dilation_h, dilation_w});
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ImplicitGemm helper functions //
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Determine the problem size of the implicit GEMM operation
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord implicit_gemm_problem_size(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
// Compute problem size
|
||||
switch (conv_operator) {
|
||||
case Operator::kFprop:
|
||||
return gemm::GemmCoord(
|
||||
problem_size.N * problem_size.Z * problem_size.P * problem_size.Q,
|
||||
problem_size.K,
|
||||
problem_size.T * problem_size.R * problem_size.S * problem_size.C
|
||||
);
|
||||
case Operator::kDgrad:
|
||||
return gemm::GemmCoord(
|
||||
problem_size.N * problem_size.D * problem_size.H * problem_size.W,
|
||||
problem_size.C,
|
||||
problem_size.T * problem_size.R * problem_size.S * problem_size.K
|
||||
);
|
||||
case Operator::kWgrad:
|
||||
return gemm::GemmCoord(
|
||||
problem_size.K,
|
||||
problem_size.T * problem_size.R * problem_size.S * problem_size.C,
|
||||
problem_size.N * problem_size.Z * problem_size.P * problem_size.Q
|
||||
);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return gemm::GemmCoord();
|
||||
}
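// Relative to the Conv2d mapping, the Conv3d variant above folds the depth dimension in:
// for Fprop the GEMM M dimension gains a factor of Z (output depth) and the GEMM K
// dimension a factor of T (filter depth). For example (illustrative values), with
// N = 1, Z = P = Q = 8, K = 32, T = R = S = 3, C = 16:
//   gemm_m = 1 * 8 * 8 * 8 = 512,  gemm_n = 32,  gemm_k = 3 * 3 * 3 * 16 = 432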
|
||||
|
||||
// Determine the number of gemm_k iterations for conv3d problem using implicit gemm algorithm
|
||||
CUTLASS_HOST_DEVICE
|
||||
int implicit_gemm_k_iterations(
|
||||
Operator conv_operator,
|
||||
int threadblock_K,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
|
||||
int iterations = 0;
|
||||
int elements_per_split_k_slice = 0;
|
||||
|
||||
switch (conv_operator) {
|
||||
case Operator::kFprop:
|
||||
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
||||
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
|
||||
break;
|
||||
|
||||
case Operator::kDgrad:
|
||||
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
||||
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
|
||||
break;
|
||||
|
||||
case Operator::kWgrad:
|
||||
elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
|
||||
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return iterations;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Returns ImplicitGemm tensor A extent as Tensor5DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.activation_extent();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.output_extent();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.output_extent();
|
||||
default : break;
|
||||
}
|
||||
return cutlass::Tensor5DCoord();
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor B extent as Tensor5DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.filter_extent();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent();
|
||||
default : break;
|
||||
}
|
||||
return cutlass::Tensor5DCoord();
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor C extent as Tensor5DCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.output_extent();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent();
|
||||
default : break;
|
||||
}
|
||||
return cutlass::Tensor5DCoord();
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor A size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t implicit_gemm_tensor_a_size(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.activation_size();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.output_size();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.output_size();
|
||||
default : break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor B size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t implicit_gemm_tensor_b_size(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.filter_size();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.filter_size();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.activation_size();
|
||||
default : break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Returns ImplicitGemm tensor C size in number of elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int64_t implicit_gemm_tensor_c_size(
|
||||
Operator conv_operator,
|
||||
Conv3dProblemSize const &problem_size) {
|
||||
switch (conv_operator) {
|
||||
case cutlass::conv::Operator::kFprop: return problem_size.output_size();
|
||||
case cutlass::conv::Operator::kDgrad: return problem_size.activation_size();
|
||||
case cutlass::conv::Operator::kWgrad: return problem_size.filter_size();
|
||||
default : break;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
118
include/cutlass/conv/convolution.h
Normal file
@ -0,0 +1,118 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
|
||||
This file contains definitions and utility functions for describing convolution problem sizes in terms of
|
||||
activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w),
|
||||
dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit GEMM
|
||||
tensor extents, sizes, and data types to those of the convolution's extents, sizes, and data types.
|
||||
|
||||
* Mapping convolutions to Gemm computation *
|
||||
|
||||
CUTLASS employs the ImplicitGemm algorithm to implement convolutions. The ImplicitGemm algorithm runs a GEMM operation
|
||||
on the convolution tensors Activation, Filter, and Output. The underlying GEMM operation follows the standard
|
||||
GEMM definition:
|
||||
|
||||
C = A * B + C
|
||||
|
||||
A and B are input matrices
|
||||
C is source and output matrix
|
||||
|
||||
|
||||
For the three convolutional operators (Fprop, Dgrad, Wgrad), the ImplicitGemm matrices A, B, and C are mapped
|
||||
onto the convolution tensors Activation, Filter, and Output as per the table below:
|
||||
|
||||
___________________________________________________________________________
|
||||
ConvolutionalOperator | A | B | C
|
||||
___________________________________________________________________________
|
||||
| | | | |
|
||||
| Fprop | Activation | Filter | Output |
|
||||
| Dgrad | Output | Filter | Activation |
|
||||
| Wgrad | Output | Activation | Filter |
|
||||
___________________________________________________________________________
|
||||
|
||||
In the convolution codebase, DO NOT mix (A, B, C) with (Activation, Filter, Output).
|
||||
|
||||
For example, a convolution class/function taking A, B, Output is confusing and error-prone. Instead, use the
|
||||
mapping functions below and adhere to either A, B, C or Activation, Filter, Output.
|
||||
|
||||
Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap
|
||||
Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap
|
||||
*/
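// As a concrete reading of the table above (an illustrative sketch; `problem` is assumed
// to be a populated Conv2dProblemSize): for Dgrad, tensor A is the output gradient (NPQK),
// tensor B is the filter (KRSC), and tensor C is the activation gradient (NHWC). The free
// functions declared in conv2d_problem_size.h follow the same table, e.g.
//
//   cutlass::Tensor4DCoord a = implicit_gemm_tensor_a_extent(Operator::kDgrad, problem); // NPQK
//   cutlass::Tensor4DCoord b = implicit_gemm_tensor_b_extent(Operator::kDgrad, problem); // KRSC
//   cutlass::Tensor4DCoord c = implicit_gemm_tensor_c_extent(Operator::kDgrad, problem); // NHWC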
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Convolutional operator
|
||||
enum class Operator {
|
||||
kFprop,
|
||||
kDgrad,
|
||||
kWgrad
|
||||
};
|
||||
|
||||
/// Distinguishes convolution from cross correlation
|
||||
enum class Mode {
|
||||
kCrossCorrelation,
|
||||
kConvolution
|
||||
};
|
||||
|
||||
/// Selects among several implementation variants trading off performance with simplicity
|
||||
enum class IteratorAlgorithm {
|
||||
kAnalytic, ///< functionally correct in all cases but lower performance
|
||||
kOptimized ///< optimized for R <= 32, S <= 32 and unity-stride dgrad
|
||||
};
|
||||
|
||||
/// Distinguishes among partial specializations that accelerate certain problems where convolution
|
||||
/// stride is unit.
|
||||
enum class StrideSupport {
|
||||
kStrided, ///< arbitrary convolution stride
|
||||
kUnity ///< unit convolution stride
|
||||
};
|
||||
|
||||
/// Identifies split-K mode
|
||||
enum class SplitKMode {
|
||||
kNone,
|
||||
kSerial,
|
||||
kParallel
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
263
include/cutlass/conv/device/implicit_gemm_convolution.h
Normal file
@ -0,0 +1,263 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for device-level Implicit GEMM Convolution
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename ImplicitGemmKernel_>
|
||||
class ImplicitGemmConvolution {
|
||||
public:
|
||||
|
||||
using ImplicitGemmKernel = ImplicitGemmKernel_;
|
||||
|
||||
using ElementA = typename ImplicitGemmKernel::ElementA;
|
||||
using LayoutA = typename ImplicitGemmKernel::LayoutA;
|
||||
using ElementB = typename ImplicitGemmKernel::ElementB;
|
||||
using LayoutB = typename ImplicitGemmKernel::LayoutB;
|
||||
using ElementC = typename ImplicitGemmKernel::ElementC;
|
||||
using LayoutC = typename ImplicitGemmKernel::LayoutC;
|
||||
using ElementAccumulator = typename ImplicitGemmKernel::ElementAccumulator;
|
||||
using ElementCompute = typename ImplicitGemmKernel::ElementCompute;
|
||||
using OperatorClass = typename ImplicitGemmKernel::OperatorClass;
|
||||
using ArchTag = typename ImplicitGemmKernel::ArchTag;
|
||||
using ThreadblockShape = typename ImplicitGemmKernel::ThreadblockShape;
|
||||
using WarpShape = typename ImplicitGemmKernel::WarpShape;
|
||||
using InstructionShape = typename ImplicitGemmKernel::InstructionShape;
|
||||
using ThreadblockSwizzle = typename ImplicitGemmKernel::ThreadblockSwizzle;
|
||||
using EpilogueOutputOp = typename ImplicitGemmKernel::EpilogueOutputOp;
|
||||
static int const kStages = ImplicitGemmKernel::kStages;
|
||||
static int const kConvDim = ImplicitGemmKernel::kConvDim;
|
||||
using WarpMmaOperator = typename ImplicitGemmKernel::WarpMmaOperator;
|
||||
using ArchMmaOperator = typename ImplicitGemmKernel::ArchMmaOperator;
|
||||
using MathOperator = typename ImplicitGemmKernel::MathOperator;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator;
|
||||
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm;
|
||||
|
||||
static int const kWarpCount =
|
||||
(ThreadblockShape::kM / WarpShape::kM) *
|
||||
(ThreadblockShape::kN / WarpShape::kN);
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename ImplicitGemmKernel::Arguments;
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename ImplicitGemmKernel::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs Implicit GEMM
|
||||
ImplicitGemmConvolution() { }
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
// dispatch to iterators
|
||||
Status status = ImplicitGemmKernel::Mma::IteratorA::can_implement(args.problem_size);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = ImplicitGemmKernel::Mma::IteratorB::can_implement(args.problem_size);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(
|
||||
threadblock_swizzle.get_tiled_shape(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices));
|
||||
|
||||
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
||||
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
||||
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size),
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices);
|
||||
|
||||
if(args.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
// Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
|
||||
// The user needs to call a reduction operator to obtain the final output tensor.
|
||||
workspace_bytes =
|
||||
sizeof(ElementAccumulator) *
|
||||
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
|
||||
size_t(grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {
|
||||
|
||||
// Split-K serial: The user workspace is used to store a semaphore and serialize writing the
|
||||
// final reduced output to the user's output tensor.
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
return workspace_bytes;
|
||||
}
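// Worked example of the sizing above (illustrative values): with parallel split-K,
// split_k_slices = 4, an output of N*P*Q*K = 1*56*56*128 = 401408 elements, and a
// 4-byte accumulator, every k-slice of the tiled grid stores a full partial output:
//   workspace_bytes = 4 * 401408 * 4 = 6422528 bytes (~6.1 MiB)
// Serial split-K instead needs only one int semaphore per output tile.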
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
if (args.problem_size.split_k_slices > 1) {
|
||||
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
|
||||
|
||||
if (status != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
// initialize the params structure from the arguments
|
||||
params_ = typename ImplicitGemmKernel::Params(
|
||||
args,
|
||||
static_cast<int *>(workspace)
|
||||
);
|
||||
|
||||
int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<ImplicitGemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
cutlass::Kernel<ImplicitGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Updates tensor pointers, the epilogue output op, and the semaphore workspace from new arguments.
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
// update the params structure from the arguments
|
||||
params_.ptr_A = args.ref_A.data();
|
||||
params_.ptr_B = args.ref_B.data();
|
||||
params_.ptr_C = args.ref_C.data();
|
||||
params_.ptr_D = args.ref_D.data();
|
||||
params_.output_op = args.output_op;
|
||||
params_.semaphore = static_cast<int *>(workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(32 * kWarpCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage));
|
||||
|
||||
cutlass::Kernel<ImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
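// Typical host-side usage (a hedged sketch: `Conv2dFpropKernel` stands for a kernel type
// composed with the DefaultConv2d* definitions elsewhere in this change, and `arguments`
// for a populated Arguments structure):
//
//   using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
//
//   Conv2dFprop conv_op;
//
//   void *workspace = nullptr;
//   cudaMalloc(&workspace, Conv2dFprop::get_workspace_size(arguments));
//
//   cutlass::Status status = Conv2dFprop::can_implement(arguments);
//   if (status == cutlass::Status::kSuccess) {
//     status = conv_op(arguments, workspace);   // initialize() followed by run()
//   }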
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
104
include/cutlass/conv/kernel/default_conv2d.h
Normal file
@ -0,0 +1,104 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
|
||||
#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
|
||||
#include "cutlass/conv/threadblock/implicit_gemm_multistage.h"
|
||||
#include "cutlass/conv/kernel/implicit_gemm_convolution.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <
|
||||
typename ArchTag,
|
||||
typename Shape,
|
||||
typename WarpMmaTensorOp,
|
||||
int PartitionsK,
|
||||
typename OutputOp
|
||||
>
|
||||
struct DefaultConvEpilogue {
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
OutputOp,
|
||||
OutputOp::kCount
|
||||
>::Epilogue;
|
||||
};
|
||||
|
||||
template <
|
||||
typename Shape,
|
||||
typename WarpMmaTensorOp,
|
||||
int PartitionsK,
|
||||
typename OutputOp
|
||||
>
|
||||
struct DefaultConvEpilogue<
|
||||
arch::Sm70,
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
PartitionsK,
|
||||
OutputOp
|
||||
> {
|
||||
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
OutputOp,
|
||||
OutputOp::kCount
|
||||
>::Epilogue;
|
||||
};
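// Selection sketch (hedged; the shape and operator types below are placeholders supplied by
// a DefaultConv2d* definition): the primary template picks the Turing/Ampere tensor-op
// epilogue, while the arch::Sm70 specialization above swaps in the Volta variant.
//
//   using Epilogue = typename detail::DefaultConvEpilogue<
//     ArchTag, ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp>::Epilogue;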
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
1154
include/cutlass/conv/kernel/default_conv2d_dgrad.h
Normal file
File diff suppressed because it is too large
1379
include/cutlass/conv/kernel/default_conv2d_fprop.h
Normal file
File diff suppressed because it is too large
928
include/cutlass/conv/kernel/default_conv2d_wgrad.h
Normal file
@ -0,0 +1,928 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv2dWgrad;

/////////////////////////////////////////////////////////////////////////////////////////////////
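//
// Editorial note (not part of the original header): to make the role of these template
// parameters concrete, the sketch below shows one way a Wgrad kernel could be instantiated
// in user code, after including this header and the usual CUTLASS GEMM/epilogue headers.
// Every name other than DefaultConv2dWgrad itself (half_t, TensorNHWC, GemmShape,
// LinearCombination, GemmIdentityThreadblockSwizzle, OpMultiplyAdd) is defined elsewhere in
// CUTLASS, and the particular tile shapes and stage count are illustrative assumptions only.
//
using WgradKernel = cutlass::conv::kernel::DefaultConv2dWgrad<
  cutlass::half_t, cutlass::layout::TensorNHWC,    // ElementA / LayoutA: output gradient
  cutlass::half_t, cutlass::layout::TensorNHWC,    // ElementB / LayoutB: activation
  float,           cutlass::layout::TensorNHWC,    // ElementC / LayoutC: weight gradient
  float,                                           // ElementAccumulator
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 32>,          // ThreadblockShape
  cutlass::gemm::GemmShape<64, 64, 32>,            // WarpShape
  cutlass::gemm::GemmShape<16, 8, 16>,             // InstructionShape
  cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,                                               // Stages
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;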
/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Always,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
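//
// Editorial note (not part of the original header): in the implicit GEMM mapping used for Wgrad
// (described in media/docs/implicit_gemm_convolution.md in this commit), the weight gradient is
// computed as a GEMM with
//
//   GEMM_M = K,   GEMM_N = R * S * C,   GEMM_K = N * P * Q
//
// which is why operand A above iterates over the output-gradient tensor and operand B iterates
// over the activations, while the epilogue writes the filter-shaped result.
//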

/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two-stage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, 2, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmPipelined<
    ThreadblockShape,
    IteratorA, SmemIteratorA,
    IteratorB, SmemIteratorB,
    ElementC, LayoutC,
    MmaPolicy>;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag, ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Always,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two-stage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, 2, MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmPipelined<
    ThreadblockShape,
    IteratorA, SmemIteratorA,
    IteratorB, SmemIteratorB,
    ElementC, LayoutC,
    MmaPolicy>;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag, ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassSimt convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm,
/// multi-stage pipeline, and FFMA-based mainloop for SM80
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename ArchTag, typename ThreadblockShape, typename WarpShape,
  typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle,
  int Stages, typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Always,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
    ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
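//
// Editorial note (not part of the original header): an illustrative SIMT (FFMA) instantiation of
// the specialization above, as it might appear in user code after including this header and the
// usual CUTLASS GEMM/epilogue headers. All names other than DefaultConv2dWgrad are defined
// elsewhere in CUTLASS, and the tile sizes and stage count are assumptions chosen for the example.
//
using SimtWgradKernel = cutlass::conv::kernel::DefaultConv2dWgrad<
  float, cutlass::layout::TensorNHWC,
  float, cutlass::layout::TensorNHWC,
  float, cutlass::layout::TensorNHWC,
  float,
  cutlass::arch::OpClassSimt,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 8>,           // ThreadblockShape
  cutlass::gemm::GemmShape<32, 64, 8>,             // WarpShape
  cutlass::gemm::GemmShape<1, 1, 1>,               // InstructionShape (scalar FFMA)
  cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,                                               // Stages
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::IteratorAlgorithm::kAnalytic
>::Kernel;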

/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm,
/// multi-stage pipeline, and FFMA-based mainloop for SM80
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename ArchTag, typename ThreadblockShape, typename WarpShape,
  typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle,
  int Stages, typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Always,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
    ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm,
/// 2 stage pipeline, and FFMA-based mainloop for SM50
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename ArchTag, typename ThreadblockShape, typename WarpShape,
  typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle,
  typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, 2, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmPipelined<
    ThreadblockShape,
    IteratorA, SmemIteratorA,
    IteratorB, SmemIteratorB,
    ElementC, LayoutC,
    MmaPolicy>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
    ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm,
/// 2 stage pipeline, and FFMA-based mainloop for SM50
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename ArchTag, typename ThreadblockShape, typename WarpShape,
  typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle,
  typename MathOperatorTag
>
struct DefaultConv2dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, 2, MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmPipelined<
    ThreadblockShape,
    IteratorA, SmemIteratorA,
    IteratorB, SmemIteratorB,
    ElementC, LayoutC,
    MmaPolicy>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
    ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

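The `Kernel` aliases produced by these defaults are not launched directly in typical user code; they are wrapped by the device-level operator `cutlass::conv::device::ImplicitGemmConvolution`, which in CUTLASS 2.4 lives under include/cutlass/conv/device/. The sketch below is hedged: the header path and the initialize/launch calls are as I recall the CUTLASS device-operator API and should be checked against the actual sources; `WgradKernel` is the illustrative alias from the earlier editorial note.

// Sketch only: wrapping a kernel-level default in the device-level operator.
#include "cutlass/conv/device/implicit_gemm_convolution.h"

using Conv2dWgradOp = cutlass::conv::device::ImplicitGemmConvolution<WgradKernel>;

// Conv2dWgradOp op;
// op.initialize(args);   // args bundle a Conv2dProblemSize, tensor refs, and alpha/beta
// op();                  // launches the implicit GEMM convolution kernel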
184  include/cutlass/conv/kernel/default_conv3d_dgrad.h  Normal file
@ -0,0 +1,184 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dDgrad
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dDgrad;

/// Defines a kernel for Conv3dDgrad specialization for Analytic IteratorAlgorithm, strided Dgrad,
/// and multistage pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag
>
struct DefaultConv3dDgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kAnalytic,
  StrideSupport::kStrided
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA,
    StrideSupport::kStrided>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Global,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kDgrad, Conv3dProblemSize>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

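Worth noting, and visible in the kernel aliases in this commit: the 3-D defaults reuse the 2-D building blocks and differ structurally mainly in the problem-size type bound into `ImplicitGemmConvolution`. A minimal sketch of that difference (treating the defaulted 2-D problem-size parameter as an assumption about `ImplicitGemmConvolution`'s signature, which is defined later in this commit):

// Inside a DefaultConv2dDgrad specialization (2-D): the problem-size parameter is left defaulted.
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
  Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kDgrad>;

// Inside the DefaultConv3dDgrad specialization above (3-D): Conv3dProblemSize is bound explicitly.
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
  Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kDgrad, Conv3dProblemSize>;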
181  include/cutlass/conv/kernel/default_conv3d_fprop.h  Normal file
@ -0,0 +1,181 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dFprop
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dFprop;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename ArchTag, typename ThreadblockShape, typename WarpShape,
  typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle,
  int Stages, typename MathOperatorTag
>
struct DefaultConv3dFprop <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Global,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kFprop, Conv3dProblemSize>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

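To tie the 3-D pieces together, the following is an illustrative instantiation of the forward-propagation default above, as it might appear in user code. The names outside this commit (half_t, TensorNDHWC, GemmShape, LinearCombination, GemmIdentityThreadblockSwizzle, OpMultiplyAdd) are assumed CUTLASS types, and the tile shapes and stage count are arbitrary but plausible choices, not values prescribed by this header.

using Fprop3dKernel = cutlass::conv::kernel::DefaultConv3dFprop<
  cutlass::half_t, cutlass::layout::TensorNDHWC,
  cutlass::half_t, cutlass::layout::TensorNDHWC,
  float,           cutlass::layout::TensorNDHWC,
  float,
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 32>,
  cutlass::gemm::GemmShape<64, 64, 32>,
  cutlass::gemm::GemmShape<16, 8, 16>,
  cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,
  cutlass::arch::OpMultiplyAdd
>::Kernel;   // IteratorAlgorithm defaults to kAnalytic, matching the only specialization above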
504  include/cutlass/conv/kernel/default_conv3d_wgrad.h  Normal file
@ -0,0 +1,504 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dWgrad
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dWgrad;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag
>
struct DefaultConv3dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Always,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad, Conv3dProblemSize>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and two-stage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, typename MathOperatorTag
>
struct DefaultConv3dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, 2, MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    2, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::TileIterator<
    cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmPipelined<
    ThreadblockShape,
    IteratorA, SmemIteratorA,
    IteratorB, SmemIteratorB,
    ElementC, LayoutC,
    MmaPolicy>;

  // Define the epilogue
  using Epilogue = typename detail::DefaultConvEpilogue<
    ArchTag, ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad, Conv3dProblemSize>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA, typename LayoutA, typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC, typename ElementAccumulator,
  typename OperatorClass, typename ArchTag, typename ThreadblockShape,
  typename WarpShape, typename InstructionShape, typename EpilogueOutputOp,
  typename ThreadblockSwizzle, int Stages, typename MathOperatorTag
>
struct DefaultConv3dWgrad <
  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
  OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
    ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
    Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA = cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized<
    cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, ThreadMapA>;
  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB = cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized<
    cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, ThreadMapB>;
  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
    IteratorB, SmemIteratorB, arch::CacheOperation::Always,
    MmaPolicy, Stages>;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape, WarpMmaTensorOp, 1, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kWgrad, Conv3dProblemSize>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dWgrad specialzation for Optimized IteratorAlgorithm and two
|
||||
// pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename OperatorClass,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultConv3dWgrad <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor,
|
||||
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
ThreadMapA
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
ThreadMapB
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
1,
|
||||
EpilogueOutputOp
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kWgrad,
|
||||
Conv3dProblemSize
|
||||
>;
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
424 include/cutlass/conv/kernel/implicit_gemm_convolution.h Normal file
@ -0,0 +1,424 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined Implicit GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
||||
typename ConvProblemSize_ = Conv2dProblemSize ///! Convolution problem size (2-D or 3-D)
|
||||
>
|
||||
struct ImplicitGemmConvolution {
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static Operator const kConvolutionalOperator = ConvOperator;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename EpilogueOutputOp::ElementOutput;
|
||||
|
||||
/// Set output tensor C layout
|
||||
using LayoutC = LayoutA;
|
||||
|
||||
using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
|
||||
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using WarpMmaOperator = typename Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename WarpMmaOperator::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
|
||||
|
||||
/// Check that iterator A and iterator B have the same convolution dimension and
// set device::ImplicitGemmConvolution::kConvDim
static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
  "Convolution on different dimensions is not supported");
|
||||
static int const kConvDim = Mma::IteratorA::kConvDim;
|
||||
|
||||
/// Conv dimension and problem size structure (Conv2d or Conv3d)
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
|
||||
/// Wgrad C stride idx for implicit gemm algorithm
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorCStrideIdx =
|
||||
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
|
||||
LayoutC,
|
||||
typename Epilogue::OutputTileIterator::Layout,
|
||||
TensorRefC,
|
||||
ConvOperator,
|
||||
ConvProblemSize
|
||||
>;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ConvProblemSize problem_size;
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size
|
||||
):
|
||||
problem_size(problem_size) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size,
|
||||
TensorRefA const & ref_A,
|
||||
TensorRefB const & ref_B,
|
||||
TensorRefC const & ref_C,
|
||||
TensorRefC const & ref_D,
|
||||
typename EpilogueOutputOp::Params const & output_op,
|
||||
SplitKMode const & split_k_mode = SplitKMode::kSerial
|
||||
):
|
||||
problem_size(problem_size),
|
||||
ref_A(ref_A),
|
||||
ref_B(ref_B),
|
||||
ref_C(ref_C),
|
||||
ref_D(ref_D),
|
||||
output_op(output_op),
|
||||
split_k_mode(split_k_mode)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
ConvProblemSize problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
gemm::GemmCoord implicit_gemm_problem_size;
|
||||
int gemm_k_iterations;
|
||||
typename Mma::IteratorA::Params iterator_A;
|
||||
typename Mma::IteratorA::Element const *ptr_A;
|
||||
typename Mma::IteratorB::Params iterator_B;
|
||||
typename Mma::IteratorB::Element const *ptr_B;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_C;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_D;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
int *semaphore;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): gemm_k_iterations(0) { }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
int *semaphore = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
|
||||
// Note: grid_tiled_shape is computed in the constructor body below.
|
||||
iterator_A(args.problem_size, args.ref_A.layout()),
|
||||
ptr_A(args.ref_A.data()),
|
||||
iterator_B(args.problem_size, args.ref_B.layout()),
|
||||
ptr_B(args.ref_B.data()),
|
||||
iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)),
|
||||
ptr_C(args.ref_C.data()),
|
||||
iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)),
|
||||
ptr_D(args.ref_D.data()),
|
||||
output_op(args.output_op),
|
||||
semaphore(semaphore),
|
||||
split_k_mode(args.split_k_mode)
|
||||
{
|
||||
gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
implicit_gemm_problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ImplicitGemmConvolution() { }
|
||||
|
||||
/// Executes one ImplicitGEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.iterator_A,
|
||||
params.problem_size,
|
||||
params.ptr_A,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
threadblock_tile_idx.k() * Mma::Shape::kK
|
||||
)
|
||||
);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.iterator_B,
|
||||
params.problem_size,
|
||||
params.ptr_B,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.k() * Mma::Shape::kK,
|
||||
threadblock_tile_idx.n() * Mma::Shape::kN
|
||||
)
|
||||
);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
// Construct the semaphore.
|
||||
int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();
|
||||
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// Compute logical position within grid
|
||||
threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
threadblock_tile_idx.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params.iterator_D,
|
||||
params.ptr_D,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator reading from source accumulator tensor
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.iterator_C,
|
||||
params.ptr_C,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_idx.k()) {
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_idx.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
// Each split-k-slice writes to a unique tensor location
|
||||
else if (params.split_k_mode == SplitKMode::kParallel) {
|
||||
iterator_D.add_pointer_offset(threadblock_tile_idx.k() *
|
||||
cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size));
|
||||
}
|
||||
|
||||
// Run efficient epilogue
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_idx.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
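//
// Host-side usage sketch (illustrative only, not part of this header): how the Arguments and
// Params defined above are typically wired together before launching the kernel. 'ConvKernel',
// 'workspace', and the launch wrapper are placeholder names; in practice the device-level
// adaptor in the cutlass::conv::device namespace performs these steps.
//
//   typename ConvKernel::Arguments args(
//     problem_size,
//     ref_A, ref_B, ref_C, ref_D,
//     {alpha, beta},                       // EpilogueOutputOp::Params
//     SplitKMode::kSerial);
//
//   typename ConvKernel::Params params(args, static_cast<int *>(workspace));
//
//   dim3 grid = ThreadblockSwizzle().get_grid_shape(params.grid_tiled_shape);
//   dim3 block(ConvKernel::kThreadCount, 1, 1);
//   // launch a __global__ wrapper that constructs ConvKernel and invokes operator()(params, smem)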
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,240 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorAnalytic {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or larger.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
// For a fixed filter position (r, s), compute and store offset_k_ and offset_c_ in the strided and contiguous dimensions
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
int offset_k_[ThreadMap::Iterations::kStrided];
|
||||
int offset_c_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorAnalytic(
|
||||
Params const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
filter_r_(0),
|
||||
filter_s_(0) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous()
|
||||
+ c * ThreadMap::Delta::kContiguous;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_k_[s] =
|
||||
threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// moves to the next tile
|
||||
++filter_s_;
|
||||
if (filter_s_ < problem_size_.S) {
|
||||
return;
|
||||
}
|
||||
filter_s_ = 0;
|
||||
++filter_r_;
|
||||
if (filter_r_ < problem_size_.R) {
|
||||
return;
|
||||
}
|
||||
filter_r_ = 0;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_k_[s] += Shape::kRow * problem_size_.split_k_slices;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the filter tensor w that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int c = offset_c_[iteration_contiguous_];
|
||||
int k = offset_k_[iteration_strided_];
|
||||
|
||||
return TensorCoord(k, filter_r_, filter_s_, c);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor w
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
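//
// Illustrative access pattern (sketch only): how a threadblock-scoped mainloop typically drives
// an analytic tile-access iterator like the one above. 'frag' and the surrounding loop are
// placeholders; the real consumers are the ImplicitGemmPipelined / ImplicitGemmMultistage
// mainloops introduced elsewhere in this commit.
//
//   iterator.set_iteration_index(0);
//   CUTLASS_PRAGMA_UNROLL
//   for (int idx = 0; idx < ThreadMap::Iterations::kCount; ++idx) {
//     if (iterator.valid()) {
//       frag[idx] = *iterator.get();      // one AccessType vector per access
//     }
//     ++iterator;                         // next access within the current (r, s) position
//   }
//   iterator.advance();                   // move to the next filter position / K offset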
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -0,0 +1,283 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
|
||||
>
|
||||
class Conv2dDgradFilterTileAccessIteratorOptimized {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = StrideSupport_;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params : Conv2dDgradFilterIteratorOptimizedParams {
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Conv2dDgradFilterIteratorOptimizedParams const &base):
|
||||
Conv2dDgradFilterIteratorOptimizedParams(base) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
):
|
||||
Conv2dDgradFilterIteratorOptimizedParams(
|
||||
problem_size,
|
||||
layout,
|
||||
sizeof_bits<Element>::value,
|
||||
{Shape::kRow, Shape::kColumn},
|
||||
ThreadMap::kThreads,
|
||||
ThreadMap::kElementsPerAccess,
|
||||
{ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
|
||||
{ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
|
||||
) { }
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Conv2dDgradFilterIteratorOptimizedParams const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
int filter_rs_;
|
||||
int filter_k_;
|
||||
|
||||
//
|
||||
// Assertions
|
||||
//
|
||||
|
||||
// We map predicates into bits packed in this uint32_t container
|
||||
static_assert(ThreadMap::Iterations::kStrided *
|
||||
ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8,
|
||||
"Currently, the number of loads per iteration is limited by the size of the predicates container.");
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized(
|
||||
Conv2dDgradFilterIteratorOptimizedParams const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
filter_rs_(0),
|
||||
filter_k_(0) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.row() + thread_coord.strided();
|
||||
Index column = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided;
|
||||
int filter_c = column + c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_ |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
filter_k_ * params.layout.stride()[2] + column
|
||||
) * sizeof_bits<Element>::value / 8;
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
LongIndex next = params_.inc_next_rs;
|
||||
|
||||
// moves to the next tile
|
||||
++filter_rs_;
|
||||
if (filter_rs_ == params_.RS) {
|
||||
|
||||
filter_rs_ = 0;
|
||||
next = params_.inc_next_k;
|
||||
filter_k_ += params_.filter_k_delta;
|
||||
}
|
||||
|
||||
// Clear predicates if needed
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
predicates_ = (predicates_ & (~kClearMask));
|
||||
}
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_ & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_ +
|
||||
iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
|
||||
// Move to the next K coordinate within the tile
|
||||
pointer_ += params_.inc_next_strided;
|
||||
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -0,0 +1,525 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_,
|
||||
conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for strided dgrad needs special handling
// using unscaled coordinates
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kStrided
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
|
||||
int offset_n_[ThreadMap::Iterations::kStrided];
|
||||
int offset_w_[ThreadMap::Iterations::kStrided];
|
||||
int offset_h_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
private:
|
||||
|
||||
/// Returns the coordinate in the output tensor Dy that is currently pointed to
|
||||
/// by the iterator but DOES NOT scale by the convolution stride. This is needed
|
||||
/// to compute predicates in the valid() method. The return value of the public at()
|
||||
/// method is correctly scaled.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord unscaled_at_() const {
|
||||
int n = offset_n_[iteration_strided_];
|
||||
int h = offset_h_[iteration_strided_];
|
||||
int w = offset_w_[iteration_strided_];
|
||||
|
||||
int r = filter_r_;
|
||||
int s = filter_s_;
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
r = (problem_size_.R - 1 - r);
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h);
|
||||
int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w);
|
||||
|
||||
return TensorCoord(n, p, q, filter_k_);
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic(
|
||||
Params const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
filter_k_(0),
|
||||
filter_r_(0),
|
||||
filter_s_(0) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
|
||||
|
||||
offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W);
|
||||
int residual = offset_nhw % (problem_size_.H * problem_size_.W);
|
||||
|
||||
offset_h_[s] = residual / problem_size_.W;
|
||||
offset_w_[s] = residual % problem_size_.W;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// move to the next tile
|
||||
++filter_s_;
|
||||
if (filter_s_ < problem_size_.S) {
|
||||
return;
|
||||
}
|
||||
filter_s_ = 0;
|
||||
++filter_r_;
|
||||
if (filter_r_ < problem_size_.R) {
|
||||
return;
|
||||
}
|
||||
filter_r_ = 0;
|
||||
|
||||
filter_k_ += Shape_::kColumn * problem_size_.split_k_slices;
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the output tensor Dy that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
TensorCoord coord = unscaled_at_();
|
||||
|
||||
return TensorCoord(
|
||||
coord.n(),
|
||||
coord.h() / problem_size_.stride_h,
|
||||
coord.w() / problem_size_.stride_w,
|
||||
coord.c());
|
||||
}
|
||||
|
||||
|
||||
/// Returns true if the current coordinate is within the output tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
|
||||
TensorCoord unscaled_coord = unscaled_at_();
|
||||
TensorCoord coord = at();
|
||||
|
||||
return
|
||||
!(unscaled_coord.h() % problem_size_.stride_h) && !(unscaled_coord.w() % problem_size_.stride_w) &&
|
||||
coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.P &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.Q &&
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
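//
// Worked example for the strided specialization above (illustrative numbers): with
// stride_h = 2, pad_h = 1, dilation_h = 1 and dgrad output row h = 3,
//   r = 0:  unscaled p = h + pad_h - r * dilation_h = 4;  4 % stride_h == 0  -> valid, p = 4 / 2 = 2
//   r = 1:  unscaled p = 3;                               3 % stride_h != 0  -> masked off by valid()
// This is why unscaled_at_() is kept separate from at(): the divisibility test must be applied to
// the unscaled coordinate before dividing by the convolution stride.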
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The Conv2dDgradOutputGradientTileAccessIteratorAnalytic specialization for unity strides can be
// optimized by eliminating the modulo arithmetic needed to compute unscaled coordinates
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv2dDgradOutputGradientTileAccessIteratorAnalytic <
|
||||
Shape_,
|
||||
Element_,
|
||||
ThreadMap_,
|
||||
conv::StrideSupport::kUnity
|
||||
> {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"DGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
Layout layout;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
): layout(layout) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
int filter_r_;
|
||||
int filter_s_;
|
||||
|
||||
int offset_n_[ThreadMap::Iterations::kStrided];
|
||||
int offset_w_[ThreadMap::Iterations::kStrided];
|
||||
int offset_h_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic(
|
||||
Params const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
filter_k_(0),
|
||||
filter_r_(0),
|
||||
filter_s_(0) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
|
||||
|
||||
offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W);
|
||||
int residual = offset_nhw % (problem_size_.H * problem_size_.W);
|
||||
|
||||
offset_h_[s] = residual / problem_size_.W;
|
||||
offset_w_[s] = residual % problem_size_.W;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// move to the next tile
|
||||
++filter_s_;
|
||||
if (filter_s_ < problem_size_.S) {
|
||||
return;
|
||||
}
|
||||
filter_s_ = 0;
|
||||
++filter_r_;
|
||||
if (filter_r_ < problem_size_.R) {
|
||||
return;
|
||||
}
|
||||
filter_r_ = 0;
|
||||
|
||||
filter_k_ += Shape_::kColumn * problem_size_.split_k_slices;
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the output tensor Dy that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int n = offset_n_[iteration_strided_];
|
||||
int h = offset_h_[iteration_strided_];
|
||||
int w = offset_w_[iteration_strided_];
|
||||
|
||||
int r = filter_r_;
|
||||
int s = filter_s_;
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
r = (problem_size_.R - 1 - r);
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h;
|
||||
int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w;
|
||||
|
||||
return TensorCoord(n, p, q, filter_k_);
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Returns true if the current coordinate is within the output tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.P &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.Q &&
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// The Conv2dDgradOutputGradientTileAccessIteratorAnalytic unity-stride specialization
// only supports (stride_h, stride_w) = (1, 1)
|
||||
if (problem_size.stride() != MatrixCoord({1, 1})) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
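The analytic dgrad iterator above recovers an output-gradient coordinate (p, q) from an input position (h, w) and a filter position (r, s), then discards out-of-bounds accesses in valid(). The following is a minimal host-side sketch of that unit-stride arithmetic, not library code; the struct and function names are illustrative only.

// Host-side sketch of the at()/valid() mapping above, assuming unit stride.
#include <cstdio>

struct Conv2dSizes {
  int P, Q;                       // output-gradient extent (Dy)
  int pad_h, pad_w;
  int dilation_h, dilation_w;
};

// Writes (p, q) and returns true if input position (h, w) with filter position (r, s)
// maps to an in-bounds Dy element; mirrors the unit-stride arithmetic in at()/valid().
bool dy_coord_for(Conv2dSizes const &sz, int h, int w, int r, int s, int &p, int &q) {
  p = h + sz.pad_h - r * sz.dilation_h;   // stride_h == 1
  q = w + sz.pad_w - s * sz.dilation_w;   // stride_w == 1
  return p >= 0 && p < sz.P && q >= 0 && q < sz.Q;
}

int main() {
  Conv2dSizes sz{8, 8, 1, 1, 1, 1};       // 8x8 Dy, padding 1, dilation 1
  int p, q;
  if (dy_coord_for(sz, /*h=*/0, /*w=*/0, /*r=*/2, /*s=*/2, p, q)) {
    std::printf("maps to Dy(p=%d, q=%d)\n", p, q);
  } else {
    std::printf("out of bounds: this filter tap contributes zero\n");
  }
  return 0;
}
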
@ -0,0 +1,437 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
    matrix from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_,
  conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity
>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized {
public:

  static_assert(StrideSupport_ == conv::StrideSupport::kUnity,
    "Only unit-stride dgrad is supported at this time.");

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNHWC;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  using Mask = uint64_t;

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  struct Params : Conv2dDgradOutputGradientIteratorOptimizedParams {

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Conv2dDgradOutputGradientIteratorOptimizedParams const &base):
      Conv2dDgradOutputGradientIteratorOptimizedParams(base) { }

    CUTLASS_HOST_DEVICE
    Params(
      Conv2dProblemSize const &problem_size,
      Layout const &layout
    ):
      Conv2dDgradOutputGradientIteratorOptimizedParams(
        problem_size,
        layout,
        sizeof_bits<Element>::value,
        {Shape::kRow, Shape::kColumn},
        ThreadMap::kThreads,
        ThreadMap::kElementsPerAccess,
        {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
        {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
      ) { }
  };

private:

  Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;

  // One pointer per access
  char const *pointer_[ThreadMap::Iterations::kStrided];

  // current filter position (r, s)
  int filter_r_;
  int filter_s_;
  int filter_k_;

  Index masks_[ThreadMap::Iterations::kStrided][2];

public:

  CUTLASS_HOST_DEVICE
  Conv2dDgradOutputGradientTileAccessIteratorOptimized(
    Conv2dDgradOutputGradientIteratorOptimizedParams const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()   // tile index - units are threadblock-scoped tiles
  ):
    params_(params),
    problem_size_(problem_size),
    filter_k_(0),
    filter_r_(0),
    filter_s_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_k_ = threadblock_offset.column() + thread_coord.contiguous();

    int offset_n[ThreadMap::Iterations::kStrided];
    int offset_h[ThreadMap::Iterations::kStrided];
    int offset_w[ThreadMap::Iterations::kStrided];

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {

      pointer_[s] = reinterpret_cast<char const *>(ptr);

      int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;

      // The subsequent fast_divmod() operations are equivalent to the following logical computation:
      //
      //  offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W);
      //  int residual = offset_nhw % (problem_size_.H * problem_size_.W);
      //
      //  offset_h[s] = residual / problem_size_.W;
      //  offset_w[s] = residual % problem_size_.W;
      //

      int residual;

      params_.hw_divmod(offset_n[s], residual, offset_nhw);
      params_.w_divmod(offset_h[s], offset_w[s], residual);

      TensorCoord coord = at_(offset_n[s], offset_h[s], offset_w[s], 0, 0);

      pointer_[s] += params_.layout(coord) * sizeof_bits<Element>::value / 8;
    }

    clear_mask();

    CUTLASS_PRAGMA_NO_UNROLL
    for (int r = 0; r < problem_size_.R; ++r) {
      CUTLASS_PRAGMA_UNROLL
      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

        int r_ = r;
        if (problem_size_.mode == Mode::kConvolution) {
          r_ = problem_size_.R - 1 - r;
        }

        int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h;

        bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);
        masks_[s_idx][0] |= (pred << r);
      }
    }

    CUTLASS_PRAGMA_NO_UNROLL
    for (int s = 0; s < problem_size_.S; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

        int s_ = s;
        if (problem_size_.mode == Mode::kConvolution) {
          s_ = problem_size_.S - 1 - s;
        }

        int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w;

        bool pred = (q >= 0 && q < problem_size_.Q);
        masks_[s_idx][1] |= (pred << s);
      }
    }

    if (filter_k_ >= problem_size.K) {
      clear_mask();
    }

    set_iteration_index(0);
  }

private:

  /// Returns the coordinate in the output gradient tensor Dy corresponding to
  /// output nhw and filter position k, r, s
  CUTLASS_HOST_DEVICE
  TensorCoord at_(int n, int h, int w, int r, int s) const {

    if (problem_size_.mode == Mode::kConvolution) {
      r = problem_size_.R - 1 - r;
      s = problem_size_.S - 1 - s;
    }

    int p = h + problem_size_.pad_h - r * problem_size_.dilation_h;
    int q = w + problem_size_.pad_w - s * problem_size_.dilation_w;

    return TensorCoord(n, p, q, filter_k_);
  }

  /// Adds a pointer offset in units of bytes
  CUTLASS_HOST_DEVICE
  void add_byte_offset_(LongIndex byte_offset) {

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      pointer_[s] += byte_offset;
    }
  }

  /// Clears the predicates
  CUTLASS_HOST_DEVICE
  void clear_mask_(bool clear) {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {

      // We are using inline PTX assembly here to avoid a CUDA C++ compilation
      // artifact in which control flow instructions are generated. Instead, our
      // intent is to predicate the mov instructions.
      #if defined(__CUDA_ARCH__)
      asm volatile(
          "{\n"
          "  .reg .pred p;\n"
          "  .reg .u32 m;"
          "   mov.u32 m, %2;"
          "   setp.ne.b32 p, %1, 0;\n"
          "  @p mov.u32 m, 0;\n"
          "   mov.u32 %0, m;\n"
          "}\n"
        :
          "=r"(masks_[s][0])
        :
          "r"((int)clear),
          "r"(masks_[s][0])
      );
      asm volatile(
          "{\n"
          "  .reg .pred p;\n"
          "  .reg .u32 m;"
          "   mov.u32 m, %2;"
          "   setp.ne.b32 p, %1, 0;\n"
          "  @p mov.u32 m, 0;\n"
          "   mov.u32 %0, m;\n"
          "}\n"
        :
          "=r"(masks_[s][1])
        :
          "r"((int)clear),
          "r"(masks_[s][1])
      );
      #else
      if (clear) {
        masks_[s][0] = 0;
        masks_[s][1] = 0;
      }
      #endif
    }
  }

public:

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_HOST_DEVICE
  void advance() {

    int next_idx = 0;

    // moves to the next tile
    ++filter_s_;
    if (filter_s_ == problem_size_.S) {
      filter_s_ = 0;
      ++filter_r_;

      if (filter_r_ < problem_size_.R) {
        next_idx = 1;
      }
      else {
        filter_r_ = 0;
        next_idx = 2;
      }
    }

    add_byte_offset_(params_.inc_next[next_idx]);

    if (next_idx == 2) {
      filter_k_ += params_.filter_k_delta;
    }

    clear_mask_(filter_k_ >= problem_size_.K);
  }

  /// Clears the predicates
  CUTLASS_HOST_DEVICE
  void clear_mask() {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      masks_[s][0] = Mask(0);
      masks_[s][1] = Mask(0);
    }
  }

  CUTLASS_HOST_DEVICE
  bool valid() {

    return
      (masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
      (masks_[iteration_strided_][1] & (Index(1) << filter_s_));
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {

    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;

    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // This iterator is specialized for unit stride
    if (problem_size.stride() != MatrixCoord({1, 1})) {
      return Status::kErrorNotSupported;
    }

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.K % (128 / sizeof_bits<Element>::value)) {
      return Status::kErrorNotSupported;
    }

    // Limit on filter size
    if (problem_size.R > 32 || problem_size.S > 32) {
      return Status::kErrorNotSupported;
    }
    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
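The optimized dgrad iterator above replaces per-access bound checks with two bit masks per strided access: bit r of one mask and bit s of the other record whether filter position (r, s) touches an in-bounds Dy element, and clear_mask_() zeroes the masks with a predicated PTX mov so no divergent branch is emitted. The snippet below is a small host-side sketch of that logic with illustrative names; it is not library code.

// Host-side sketch of the mask logic used by clear_mask_() and valid() above.
#include <cstdint>
#include <cassert>

// Logical equivalent of the inline PTX: zero the mask when `clear` is true, otherwise keep it.
// The PTX form expresses this as a predicated mov rather than a branch.
inline uint32_t clear_mask_word(uint32_t mask, bool clear) {
  return clear ? 0u : mask;
}

// The per-access test in valid(): both the r-bit and the s-bit must be set.
inline bool is_valid(uint32_t mask_r, uint32_t mask_s, int r, int s) {
  return (mask_r & (1u << r)) && (mask_s & (1u << s));
}

int main() {
  uint32_t mask_r = 0b101;   // filter rows r = 0 and r = 2 are in bounds
  uint32_t mask_s = 0b011;   // filter columns s = 0 and s = 1 are in bounds
  assert(is_valid(mask_r, mask_s, 2, 1));
  assert(!is_valid(mask_r, mask_s, 1, 0));
  assert(clear_mask_word(mask_r, true) == 0u);
  assert(clear_mask_word(mask_r, false) == mask_r);
  return 0;
}
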
@ -0,0 +1,274 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
    matrix from memory.

    This iterator assumes TensorNHWC or TensorNCxHWx<Interleave> layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  typename ThreadMap_
>
class Conv2dFpropActivationTileAccessIteratorAnalytic {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = Layout_;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  using Params = Conv2dAnalyticParams<Layout>;

private:

  Params const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  int filter_c_;
  int filter_r_;
  int filter_s_;

  int offset_n_[ThreadMap::Iterations::kStrided];
  int offset_p_[ThreadMap::Iterations::kStrided];
  int offset_q_[ThreadMap::Iterations::kStrided];

public:

  CUTLASS_HOST_DEVICE
  Conv2dFpropActivationTileAccessIteratorAnalytic(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()   // tile index - units are threadblock-scoped tiles
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_c_(0),
    filter_r_(0),
    filter_s_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_c_ = threadblock_offset.column() + thread_coord.contiguous();

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;

      offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q);
      int residual = offset_npq % (problem_size_.P * problem_size_.Q);

      offset_p_[s] = residual / problem_size_.Q;
      offset_q_[s] = residual % problem_size_.Q;
    }

    set_iteration_index(0);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next tile
    ++filter_s_;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    filter_s_ = 0;
    ++filter_r_;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    filter_r_ = 0;

    filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
  }

  /// Returns the coordinate in the activations tensor X that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {
    int n = offset_n_[iteration_strided_];
    int p = offset_p_[iteration_strided_];
    int q = offset_q_[iteration_strided_];

    int r = filter_r_;
    int s = filter_s_;

    if (problem_size_.mode == Mode::kConvolution) {
      r = (problem_size_.R - 1 - filter_r_);
      s = (problem_size_.S - 1 - filter_s_);
    }

    int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
    int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;

    return TensorCoord(n, h, w, filter_c_);
  }

  /// Returns true if the current coordinate is within the activations tensor X
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord coord = at();

    return coord.n() < problem_size_.N &&
      coord.h() >= 0 && coord.h() < problem_size_.H &&
      coord.w() >= 0 && coord.w() < problem_size_.W &&
      coord.c() < problem_size_.C;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);

    return ptr;
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;

    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.C % (128 / sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    if (platform::is_same<Layout, layout::TensorNCxHWx<32>>::value) {
      if (problem_size.C % 32) {
        return Status::kErrorInvalidProblem;
      }
    }

    if (platform::is_same<Layout, layout::TensorNCxHWx<64>>::value) {
      if (problem_size.C % 64) {
        return Status::kErrorInvalidProblem;
      }
    }

    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
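The fprop activation iterator above maps an output position (p, q) and filter position (r, s) to the activation element (h, w); elements that land in the padding region are skipped by valid(). The following host-side sketch mirrors that arithmetic with illustrative names; it is not library code.

// Host-side sketch of the at()/valid() mapping used by the fprop activation iterator above.
#include <cstdio>

struct FpropSizes {
  int H, W;                     // activation extent
  int pad_h, pad_w;
  int stride_h, stride_w;
  int dilation_h, dilation_w;
};

// Writes (h, w) and returns whether it lies inside the activation tensor.
bool activation_coord(FpropSizes const &sz, int p, int q, int r, int s, int &h, int &w) {
  h = p * sz.stride_h - sz.pad_h + r * sz.dilation_h;
  w = q * sz.stride_w - sz.pad_w + s * sz.dilation_w;
  return h >= 0 && h < sz.H && w >= 0 && w < sz.W;   // same bound check as valid()
}

int main() {
  FpropSizes sz{56, 56, 1, 1, 2, 2, 1, 1};            // 56x56 input, pad 1, stride 2, dilation 1
  int h, w;
  bool ok = activation_coord(sz, /*p=*/0, /*q=*/0, /*r=*/0, /*s=*/0, h, w);
  std::printf("h=%d w=%d %s\n", h, w, ok ? "valid" : "in padding");
  return 0;
}
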
@ -0,0 +1,438 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
    matrix from memory.

    This iterator assumes TensorNHWC or TensorNCxHWx<Interleave> layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  typename ThreadMap_
>
class Conv2dFpropActivationTileAccessIteratorOptimized {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = Layout_;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  using Mask = uint64_t;

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  struct Params : Conv2dFpropActivationIteratorOptimizedParams<Layout> {

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Conv2dFpropActivationIteratorOptimizedParams<Layout> const &base):
      Conv2dFpropActivationIteratorOptimizedParams<Layout>(base) { }

    CUTLASS_HOST_DEVICE
    Params(
      Conv2dProblemSize const &problem_size,
      Layout const &layout
    ):
      Conv2dFpropActivationIteratorOptimizedParams<Layout>(
        problem_size,
        layout,
        sizeof_bits<Element>::value,
        {Shape::kRow, Shape::kColumn},
        ThreadMap::kThreads,
        ThreadMap::kElementsPerAccess,
        {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
        {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
      ) { }
  };

private:

  Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;

  // One pointer per access
  char const *pointer_[ThreadMap::Iterations::kStrided];

  // current filter position (r, s)
  int filter_r_;
  int filter_s_;
  int filter_c_;

  Index masks_[ThreadMap::Iterations::kStrided][2];

public:

  CUTLASS_HOST_DEVICE
  Conv2dFpropActivationTileAccessIteratorOptimized(
    Conv2dFpropActivationIteratorOptimizedParams<Layout> const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()   // tile index - units are threadblock-scoped tiles
  ):
    params_(params),
    problem_size_(problem_size),
    filter_c_(0),
    filter_r_(0),
    filter_s_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_c_ = threadblock_offset.column() + thread_coord.contiguous();

    int offset_n[ThreadMap::Iterations::kStrided];
    int offset_p[ThreadMap::Iterations::kStrided];
    int offset_q[ThreadMap::Iterations::kStrided];

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {

      pointer_[s] = reinterpret_cast<char const *>(ptr);

      int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;

      // The subsequent fast_divmod() operations are equivalent to the following logical computation:
      //
      //  offset_n[s] = offset_npq / (problem_size_.P * problem_size_.Q);
      //  int residual = offset_npq % (problem_size_.P * problem_size_.Q);
      //
      //  offset_p[s] = residual / problem_size_.Q;
      //  offset_q[s] = residual % problem_size_.Q;
      //

      int residual;

      params.pq_divmod(offset_n[s], residual, offset_npq);
      params.q_divmod(offset_p[s], offset_q[s], residual);

      TensorCoord coord = at_(offset_n[s], offset_p[s], offset_q[s], 0, 0);

      pointer_[s] += params_.layout(coord) * sizeof_bits<Element>::value / 8;
    }

    clear_mask();

    CUTLASS_PRAGMA_NO_UNROLL
    for (int r = 0; r < problem_size_.R; ++r) {
      CUTLASS_PRAGMA_UNROLL
      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

        int r_ = r;
        if (problem_size_.mode == Mode::kConvolution) {
          r_ = problem_size_.R - 1 - r;
        }

        int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h;

        bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H);
        masks_[s_idx][0] |= (pred << r);
      }
    }

    CUTLASS_PRAGMA_NO_UNROLL
    for (int s = 0; s < problem_size_.S; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {

        int s_ = s;
        if (problem_size_.mode == Mode::kConvolution) {
          s_ = problem_size_.S - 1 - s;
        }

        int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w;

        bool pred = (w >= 0 && w < problem_size_.W);
        masks_[s_idx][1] |= (pred << s);
      }
    }

    if (filter_c_ >= problem_size.C) {
      clear_mask();
    }

    set_iteration_index(0);
  }

private:

  /// Returns the coordinate in the activations tensor X corresponding to
  /// output npq and filter position r, s
  CUTLASS_HOST_DEVICE
  TensorCoord at_(int n, int p, int q, int r, int s) const {

    if (problem_size_.mode == Mode::kConvolution) {
      r = problem_size_.R - 1 - r;
      s = problem_size_.S - 1 - s;
    }

    int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
    int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;

    return TensorCoord(n, h, w, filter_c_);
  }

  /// Adds a pointer offset in units of bytes
  CUTLASS_HOST_DEVICE
  void add_byte_offset_(LongIndex byte_offset) {

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      pointer_[s] += byte_offset;
    }
  }

  /// Clears the predicates
  CUTLASS_HOST_DEVICE
  void clear_mask_(bool clear) {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {

      // We are using inline PTX assembly here to avoid a CUDA C++ compilation
      // artifact in which control flow instructions are generated. Instead, our
      // intent is to predicate the mov instructions.
      #if defined(__CUDA_ARCH__)
      asm volatile(
          "{\n"
          "  .reg .pred p;\n"
          "  .reg .u32 m;"
          "   mov.u32 m, %2;"
          "   setp.ne.b32 p, %1, 0;\n"
          "  @p mov.u32 m, 0;\n"
          "   mov.u32 %0, m;\n"
          "}\n"
        :
          "=r"(masks_[s][0])
        :
          "r"((int)clear),
          "r"(masks_[s][0])
      );
      asm volatile(
          "{\n"
          "  .reg .pred p;\n"
          "  .reg .u32 m;"
          "   mov.u32 m, %2;"
          "   setp.ne.b32 p, %1, 0;\n"
          "  @p mov.u32 m, 0;\n"
          "   mov.u32 %0, m;\n"
          "}\n"
        :
          "=r"(masks_[s][1])
        :
          "r"((int)clear),
          "r"(masks_[s][1])
      );
      #else
      if (clear) {
        masks_[s][0] = 0;
        masks_[s][1] = 0;
      }
      #endif
    }
  }

public:

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_HOST_DEVICE
  void advance() {

    int next_idx = 0;

    // moves to the next tile
    ++filter_s_;
    if (filter_s_ == problem_size_.S) {
      filter_s_ = 0;
      ++filter_r_;

      if (filter_r_ < problem_size_.R) {
        next_idx = 1;
      }
      else {
        filter_r_ = 0;
        next_idx = 2;
      }
    }

    add_byte_offset_(params_.inc_next[next_idx]);

    if (next_idx == 2) {
      filter_c_ += params_.filter_c_delta;
    }

    clear_mask_(filter_c_ >= problem_size_.C);
  }

  /// Clears the predicates
  CUTLASS_HOST_DEVICE
  void clear_mask() {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      masks_[s][0] = Mask(0);
      masks_[s][1] = Mask(0);
    }
  }

  CUTLASS_HOST_DEVICE
  bool valid() {

    return
      (masks_[iteration_strided_][0] & (Index(1) << filter_r_)) &&
      (masks_[iteration_strided_][1] & (Index(1) << filter_s_));
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    return reinterpret_cast<AccessType const *>(pointer_[iteration_strided_]);
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dFpropActivationTileAccessIteratorOptimized &operator++() {

    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;

    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.C % (128 / sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    if (platform::is_same<Layout, layout::TensorNCxHWx<32>>::value) {
      if (problem_size.C % 32) {
        return Status::kErrorInvalidProblem;
      }
    }

    if (platform::is_same<Layout, layout::TensorNCxHWx<64>>::value) {
      if (problem_size.C % 64) {
        return Status::kErrorInvalidProblem;
      }
    }

    // Conv2dFpropActivationTileAccessIteratorOptimized has a constraint on filter positions
    // due to the number of mask bits.
    if (problem_size.R > 32 || problem_size.S > 32) {
      return Status::kErrorNotSupported;
    }
    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
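The constructor above decomposes each thread's linear offset_npq index into (n, p, q) using precomputed fast_divmod parameters; logically this is just integer division and remainder, as the in-code comment states. The sketch below reproduces that decomposition on the host with plain / and %, using illustrative names only.

// Host-side sketch of the offset_npq -> (n, p, q) decomposition.
#include <cassert>

struct Npq { int n, p, q; };

Npq decompose_npq(int offset_npq, int P, int Q) {
  Npq c;
  c.n = offset_npq / (P * Q);          // image index
  int residual = offset_npq % (P * Q);
  c.p = residual / Q;                  // output row
  c.q = residual % Q;                  // output column
  return c;
}

int main() {
  int P = 28, Q = 28;
  Npq c = decompose_npq(/*offset_npq=*/2 * P * Q + 5 * Q + 3, P, Q);
  assert(c.n == 2 && c.p == 5 && c.q == 3);
  return 0;
}
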
@ -0,0 +1,252 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
    matrix from memory.

    This iterator assumes TensorNHWC or TensorCxRSKx<Interleave> layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  typename ThreadMap_
>
class Conv2dFpropFilterTileAccessIteratorAnalytic {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = Layout_;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  using Params = Conv2dAnalyticParams<Layout>;

private:

  Params const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  int filter_r_;
  int filter_s_;
  int filter_c_;

  int offset_k_[ThreadMap::Iterations::kStrided];

public:

  CUTLASS_HOST_DEVICE
  Conv2dFpropFilterTileAccessIteratorAnalytic(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_r_(0),
    filter_s_(0),
    filter_c_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_c_ = threadblock_offset.row() + thread_coord.contiguous();

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
    }

    set_iteration_index(0);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next tile
    ++filter_s_;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    filter_s_ = 0;

    ++filter_r_;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    filter_r_ = 0;

    filter_c_ += Shape::kRow * problem_size_.split_k_slices;
  }

  /// Returns the coordinate in the filter tensor W that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {

    int k = offset_k_[iteration_strided_];

    return TensorCoord(k, filter_r_, filter_s_, filter_c_);
  }

  /// Returns true if the current coordinate is within the filter tensor W
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord coord = at();

    return coord.n() < problem_size_.K &&
      coord.c() < problem_size_.C;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;

    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.C % (128 / sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    if (platform::is_same<Layout, layout::TensorCxRSKx<32>>::value) {
      if (problem_size.K % 32) {
        return Status::kErrorInvalidProblem;
      }
    }

    if (platform::is_same<Layout, layout::TensorCxRSKx<64>>::value) {
      if (problem_size.K % 64) {
        return Status::kErrorInvalidProblem;
      }
    }

    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
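The can_implement() check above enforces 128-bit alignment on the iterator's contiguous dimension: the extent must be a multiple of the number of elements that fit in one 128-bit access. The arithmetic below is a minimal host-side sketch of that rule with illustrative names; it is not library code.

// Host-side sketch of the 128-bit alignment rule used by can_implement() above.
#include <cassert>

constexpr int elements_per_128b(int bits_per_element) {
  return 128 / bits_per_element;
}

constexpr bool contiguous_dim_aligned(int extent, int bits_per_element) {
  return (extent % elements_per_128b(bits_per_element)) == 0;
}

int main() {
  // For half precision (16 bits) a 128-bit access holds 8 elements, so the extent must be a multiple of 8.
  static_assert(elements_per_128b(16) == 8, "fp16 packs 8 elements per 128-bit access");
  assert(contiguous_dim_aligned(/*extent=*/64, /*bits=*/16));
  assert(!contiguous_dim_aligned(/*extent=*/60, /*bits=*/16));
  // For int4 (4 bits) the requirement becomes a multiple of 32.
  static_assert(elements_per_128b(4) == 32, "int4 packs 32 elements per 128-bit access");
  return 0;
}
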
@ -0,0 +1,282 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNHWC or TensorCxRSKx<Interleave> layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv2dFpropFilterTileAccessIteratorOptimized{
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params : Conv2dFpropFilterIteratorOptimizedParams<Layout> {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Conv2dFpropFilterIteratorOptimizedParams<Layout> const &base):
|
||||
Conv2dFpropFilterIteratorOptimizedParams<Layout>(base) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
):
|
||||
Conv2dFpropFilterIteratorOptimizedParams<Layout>(
|
||||
problem_size,
|
||||
layout,
|
||||
sizeof_bits<Element>::value,
|
||||
{Shape::kRow, Shape::kColumn},
|
||||
ThreadMap::kThreads,
|
||||
ThreadMap::kElementsPerAccess,
|
||||
{ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
|
||||
{ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
|
||||
) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Conv2dFpropFilterIteratorOptimizedParams<Layout> const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
int filter_rs_;
|
||||
int filter_c_;
|
||||
|
||||
//
|
||||
// Assertions
|
||||
//
|
||||
|
||||
// We map predicates into bits packed in this uint32_t container
|
||||
static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8,
|
||||
"Currently, the number of loads per iteration is limited by the size of the predicates container.");
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterTileAccessIteratorOptimized(
|
||||
Conv2dFpropFilterIteratorOptimizedParams<Layout> const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
filter_rs_(0),
|
||||
filter_c_(0) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_c_ = threadblock_offset.row() + thread_coord.contiguous();
|
||||
Index column = threadblock_offset.column() + thread_coord.strided();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0);
|
||||
predicates_ |= (pred << s);
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size.C) {
|
||||
predicates_ = 0u;
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
params_.layout({filter_c_, column})
|
||||
) * sizeof_bits<Element>::value / 8;
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
LongIndex next = params_.inc_next_rs;
|
||||
|
||||
// moves to the next tile
|
||||
++filter_rs_;
|
||||
if (filter_rs_ == params_.RS) {
|
||||
|
||||
filter_rs_ = 0;
|
||||
next = params_.inc_next_c;
|
||||
filter_c_ += params_.filter_c_delta;
|
||||
}
|
||||
|
||||
if (filter_c_ >= problem_size_.C) {
|
||||
predicates_ = 0;
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() {
|
||||
return (predicates_ & (1u << iteration_strided_));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
return reinterpret_cast<AccessType const *>(pointer_);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
|
||||
// Move to the next K coordinate within the tile
|
||||
pointer_ += params_.inc_next_k;
|
||||
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
if (platform::is_same<Layout, layout::TensorCxRSKx<32>>::value) {
|
||||
if (problem_size.K % 32) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
|
||||
if (platform::is_same<Layout, layout::TensorCxRSKx<64>>::value) {
|
||||
if (problem_size.K % 64) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
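The iterator above keeps one validity bit per strided load in a 32-bit mask: the constructor sets bit s when column + s * Delta::kStrided lies inside K, advance() clears the whole mask once filter_c_ walks past C, and valid() tests the bit for the current strided iteration. The standalone sketch below is an editorial illustration of that bookkeeping, not part of this commit; the loop counts and problem sizes are made up.

// Standalone illustration of the predicate packing used by
// Conv2dFpropFilterTileAccessIteratorOptimized (editorial sketch, not part of
// the commit; all sizes are illustrative).
#include <cassert>
#include <cstdint>

int main() {
  int const kIterationsStrided = 4;  // ThreadMap::Iterations::kStrided
  int const kDeltaStrided = 8;       // ThreadMap::Delta::kStrided
  int const K = 20;                  // output channels
  int column = 0;                    // starting K coordinate of this thread

  // Pack one bit per strided iteration: bit s == 1 means load s is in range.
  uint32_t predicates = 0;
  for (int s = 0; s < kIterationsStrided; ++s) {
    uint32_t pred = ((column + s * kDeltaStrided < K) ? 1u : 0u);
    predicates |= (pred << s);
  }

  // Loads 0..2 touch rows 0, 8, 16 (< 20); load 3 would touch row 24 (>= 20).
  assert(predicates == 0x7u);

  // valid() for iteration s is a single bit test.
  for (int s = 0; s < kIterationsStrided; ++s) {
    bool valid = (predicates & (1u << s)) != 0;
    assert(valid == (s < 3));
  }

  // Once filter_c_ runs past C, the iterator zeroes the whole mask so every
  // remaining access in the tile is predicated off.
  predicates = 0u;
  assert((predicates & 1u) == 0u);
  return 0;
}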
|
||||
609
include/cutlass/conv/threadblock/conv2d_params.h
Normal file
@ -0,0 +1,609 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Extracts the host-params objects into non-template code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED
|
||||
#include <fstream>
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Params structure used for all Conv2d analytic tile iterators
|
||||
template< typename Layout_ = layout::TensorNHWC >
|
||||
struct Conv2dAnalyticParams {
|
||||
|
||||
using Layout = Layout_;
|
||||
|
||||
Layout layout;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dAnalyticParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dAnalyticParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
): layout(layout) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void TraceIteratorParams(
|
||||
char const *conv_operator,
|
||||
char const *operand,
|
||||
int element_size_bits,
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
) {
|
||||
|
||||
#if !defined(__CUDA_ARCH__)
|
||||
|
||||
char const *fname = "conv_iterator_params.csv";
|
||||
|
||||
std::ifstream test(fname);
|
||||
bool file_exists = test.is_open();
|
||||
|
||||
if (file_exists) {
|
||||
test.close();
|
||||
}
|
||||
|
||||
std::ofstream trace("conv_iterator_params.csv", std::ofstream::app);
|
||||
|
||||
if (!file_exists) {
|
||||
trace
|
||||
<< "Operator,Operand,ElementSize,CtaRows,CtaColumns,ThreadCount,AccessSize,"
|
||||
<< "IterationsContiguous,IterationsStrided,DeltaContiguous,DeltaStrided\n";
|
||||
}
|
||||
|
||||
trace << conv_operator << "," << operand << "," << element_size_bits << ","
|
||||
<< threadblock_shape.row() << "," << threadblock_shape.column()
|
||||
<< "," << thread_count << "," << access_size
|
||||
<< "," << threadmap_iterations.contiguous() << "," << threadmap_iterations.strided()
|
||||
<< "," << threadmap_delta.contiguous() << "," << threadmap_delta.strided() << "\n";
|
||||
#endif
|
||||
}
|
||||
|
||||
#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) \
|
||||
TraceIteratorParams(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta);
|
||||
|
||||
#else
|
||||
|
||||
#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) {}
|
||||
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized
|
||||
template< typename Layout_ = layout::TensorNHWC >
|
||||
struct Conv2dFpropActivationIteratorOptimizedParams;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized
|
||||
template<>
|
||||
struct Conv2dFpropActivationIteratorOptimizedParams<layout::TensorNHWC> {
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
|
||||
int64_t inc_next[3]; // {next S, next R, next C}
|
||||
int filter_c_delta; // number of logical elements to add to filter_c_
|
||||
int PQ; // product of P*Q
|
||||
|
||||
FastDivmod pq_divmod;
|
||||
FastDivmod q_divmod;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout, ///< layout object
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1);
|
||||
|
||||
// next S
|
||||
inc_next[0] = conv_sign * (int64_t(layout.stride()[0]) * problem_size.dilation_w) * element_size_bits / 8;
|
||||
|
||||
// next R
|
||||
inc_next[1] = conv_sign * (
|
||||
int64_t(layout.stride()[1]) * problem_size.dilation_h
|
||||
- (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// next C
|
||||
inc_next[2] = (
|
||||
threadblock_shape.column() * problem_size.split_k_slices
|
||||
- conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h
|
||||
- conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// logical offset added to internal channel counter - units are elements, not bytes
|
||||
filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices;
|
||||
}
|
||||
};
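// ---------------------------------------------------------------------------
// Editorial worked example (not part of the original header). For a packed
// NHWC activation tensor the strides are {C, W*C, H*W*C}. With C = 64, W = 56,
// R = S = 3, unit dilation, 16-bit elements, cross-correlation mode
// (conv_sign = +1), threadblock_shape.column() = 64 and split_k_slices = 1,
// the increments computed above evaluate to
//
//   inc_next[0] (next S) = 64 * 2                      = +128 bytes
//   inc_next[1] (next R) = (56*64 - 2*64) * 2          = +6912 bytes
//   inc_next[2] (next C) = (64 - 2*56*64 - 2*64) * 2   = -14464 bytes
//
// i.e. stepping S advances one dilated column, stepping R advances one row
// while rewinding the S walk, and stepping C advances the channel position by
// a threadblock's worth of channels while rewinding the whole (R, S) walk.
// All numbers are illustrative.
// ---------------------------------------------------------------------------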
|
||||
|
||||
/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized
|
||||
template <int Interleaved_>
|
||||
struct Conv2dFpropActivationIteratorOptimizedParams<layout::TensorNCxHWx<Interleaved_>> {
|
||||
static int const kInterleaved = Interleaved_;
|
||||
|
||||
using Layout = layout::TensorNCxHWx<kInterleaved>;
|
||||
|
||||
Layout layout;
|
||||
|
||||
int64_t inc_next[3]; // {next S, next R, next C}
|
||||
int filter_c_delta; // number of logical elements to add to filter_c_
|
||||
int PQ; // product of P*Q
|
||||
|
||||
FastDivmod pq_divmod;
|
||||
FastDivmod q_divmod;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropActivationIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout, ///< layout object
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1);
|
||||
|
||||
// next S
|
||||
inc_next[0] = conv_sign * (kInterleaved * problem_size.dilation_w) * element_size_bits / 8;
|
||||
|
||||
// next R
|
||||
inc_next[1] = conv_sign * (
|
||||
int64_t(layout.stride()[0]) * problem_size.dilation_h
|
||||
- (problem_size.S - 1) * kInterleaved * problem_size.dilation_w
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// next C
|
||||
inc_next[2] = (
|
||||
threadblock_shape.column() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[1])
|
||||
- conv_sign * int64_t(problem_size.R - 1) * layout.stride()[0] * problem_size.dilation_h
|
||||
- conv_sign * int64_t(problem_size.S - 1) * kInterleaved * problem_size.dilation_w
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// logical offset added to internal channel counter - units are elements, not bytes
|
||||
filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Layout_ = layout::TensorNHWC >
|
||||
struct Conv2dFpropFilterIteratorOptimizedParams;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<>
|
||||
struct Conv2dFpropFilterIteratorOptimizedParams<layout::TensorNHWC>
|
||||
{
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
int RS;
|
||||
int filter_c_delta;
|
||||
|
||||
int64_t inc_next_k; // offset in units of bytes to next K position
|
||||
int64_t inc_next_rs; // offset in units of bytes to next RS position
|
||||
int64_t inc_next_c; // offset in units of bytes to next C position
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
RS = problem_size.R * problem_size.S;
|
||||
|
||||
inc_next_k = (int64_t(layout.stride()[2]) * threadmap_delta.strided() * element_size_bits) / 8;
|
||||
|
||||
inc_next_rs =
|
||||
( int64_t(layout.stride()[0])
|
||||
- int64_t(layout.stride()[2]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided()
|
||||
) * element_size_bits / 8;
|
||||
|
||||
inc_next_c =
|
||||
(
|
||||
threadblock_shape.row() * problem_size.split_k_slices
|
||||
- int64_t(RS - 1) * layout.stride()[0]
|
||||
- int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
) * element_size_bits / 8;
|
||||
|
||||
filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices;
|
||||
}
|
||||
};
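// ---------------------------------------------------------------------------
// Editorial note (not part of the original header). For a KRSC filter stored
// with packed NHWC-style strides {C, S*C, R*S*C}, the three byte offsets above
// encode the traversal order used by Conv2dFpropFilterTileAccessIteratorOptimized:
//
//   inc_next_k  - operator++ : step to the next K row within the current tile
//                 (Delta::kStrided rows of stride R*S*C elements each);
//   inc_next_rs - advance()  : step to the next (r, s) filter tap, rewinding
//                 the K walk performed inside the tile first;
//   inc_next_c  - advance()  : once all R*S taps are consumed, advance the
//                 channel position by threadblock_shape.row() * split_k_slices
//                 elements and rewind both the (R, S) walk and the K walk.
//
// All three are premultiplied by sizeof_bits<Element>::value / 8, so the
// device code only performs byte-pointer additions in the mainloop.
// ---------------------------------------------------------------------------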
|
||||
|
||||
template<int Interleaved_>
|
||||
struct Conv2dFpropFilterIteratorOptimizedParams<layout::TensorCxRSKx<Interleaved_>>
|
||||
{
|
||||
static int const kInterleaved = Interleaved_;
|
||||
using Layout = layout::TensorCxRSKx<kInterleaved>;
|
||||
|
||||
Layout layout;
|
||||
int RS;
|
||||
int filter_c_delta;
|
||||
|
||||
int64_t inc_next_k; // offset in units of bytes to next K position
|
||||
int64_t inc_next_rs; // offset in units of bytes to next RS position
|
||||
int64_t inc_next_c; // offset in units of bytes to next C position
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dFpropFilterIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
RS = problem_size.R * problem_size.S;
|
||||
|
||||
inc_next_k = (kInterleaved * threadmap_delta.strided() * element_size_bits) / 8;
|
||||
|
||||
inc_next_rs =
|
||||
( int64_t(layout.stride()[0])
|
||||
- kInterleaved * (threadmap_iterations.strided() - 1) * threadmap_delta.strided()
|
||||
) * element_size_bits / 8;
|
||||
|
||||
inc_next_c =
|
||||
(
|
||||
threadblock_shape.row() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[2])
|
||||
- int64_t(RS - 1) * layout.stride()[0]
|
||||
- int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * kInterleaved
|
||||
) * element_size_bits / 8;
|
||||
|
||||
filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator
|
||||
struct Conv2dDgradOutputGradientIteratorOptimizedParams {
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
|
||||
int64_t inc_next[3]; // {next S, next R, next K}
|
||||
|
||||
int filter_k_delta; // number of logical elements to add to filter_k_
|
||||
|
||||
int HW; // product of H*W
|
||||
|
||||
FastDivmod hw_divmod;
|
||||
FastDivmod w_divmod;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradOutputGradientIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout), HW(problem_size.H * problem_size.W), hw_divmod(HW), w_divmod(problem_size.W) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_dgrad", "output_gradient",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1);
|
||||
|
||||
// next S
|
||||
inc_next[0] = conv_sign * (layout.stride()[0] * problem_size.dilation_w) * element_size_bits / 8;
|
||||
|
||||
// next R
|
||||
inc_next[1] = conv_sign * (
|
||||
layout.stride()[1] * problem_size.dilation_h
|
||||
- (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// next K
|
||||
inc_next[2] = (
|
||||
threadblock_shape.column() * problem_size.split_k_slices
|
||||
- conv_sign * (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h
|
||||
- conv_sign * (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
|
||||
) * element_size_bits / 8;
|
||||
|
||||
// logical offset added to internal channel counter - units are elements, not bytes
|
||||
filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters object for Conv2d DGRAD Filter (w) iterator
|
||||
struct Conv2dDgradFilterIteratorOptimizedParams {
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
int RS;
|
||||
int filter_k_delta;
|
||||
|
||||
int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile
|
||||
int64_t inc_next_rs; // offset in units of bytes to next RS position
|
||||
int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dDgradFilterIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout), RS(problem_size.R * problem_size.S) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
|
||||
|
||||
inc_next_rs =
|
||||
( layout.stride()[0]
|
||||
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
) * element_size_bits / 8;
|
||||
|
||||
inc_next_k =
|
||||
(
|
||||
threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
|
||||
- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
|
||||
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
|
||||
) * element_size_bits / 8;
|
||||
|
||||
filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator
|
||||
struct Conv2dWgradOutputGradientIteratorOptimizedParams {
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
|
||||
int NPQ; // precomputed product of N*P*Q for clearing predicates
|
||||
|
||||
FastDivmod pq_divmod;
|
||||
FastDivmod q_divmod;
|
||||
|
||||
int64_t offset_next_strided; // offset in units of bytes to next npq coordinate within tile
|
||||
int64_t offset_next_contiguous; // offset in units of bytes to next k coordinate within tile
|
||||
int64_t inc_next_npq; // offset in units of bytes to next npq position in subsequent tile
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradOutputGradientIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradOutputGradientIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
layout(layout),
|
||||
NPQ(problem_size.N * problem_size.P * problem_size.Q),
|
||||
pq_divmod(problem_size.P * problem_size.Q),
|
||||
q_divmod(problem_size.Q) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_wgrad", "output_gradient",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
|
||||
// Incremental offsets in units of bytes: (number of elements) * sizeof_bits<Element>::value / 8
|
||||
offset_next_strided = (threadmap_delta.strided() * layout.stride()[0])
|
||||
* element_size_bits / 8;
|
||||
|
||||
offset_next_contiguous = (threadmap_delta.contiguous())
|
||||
* element_size_bits / 8;
|
||||
|
||||
inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0])
|
||||
* element_size_bits / 8;
|
||||
}
|
||||
};
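// ---------------------------------------------------------------------------
// Editorial note (not part of the original header). For a densely packed NHWC
// output gradient dy of shape (N, P, Q, K), layout.stride()[0] == K, so
//
//   offset_next_strided    steps Delta::kStrided positions along the flattened
//                          N*P*Q (gemm-K) dimension within a tile,
//   offset_next_contiguous steps Delta::kContiguous elements along K
//                          (unit stride), and
//   inc_next_npq           advances one full threadblock tile (times
//                          split_k_slices) along N*P*Q between mainloop
//                          iterations.
//
// All three are byte offsets, premultiplied by sizeof_bits<Element>::value / 8.
// ---------------------------------------------------------------------------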
|
||||
|
||||
struct Conv2dWgradActivationIteratorOptimizedParams {
|
||||
|
||||
using Layout = layout::TensorNHWC;
|
||||
|
||||
Layout layout;
|
||||
|
||||
FastDivmod sc_divmod;
|
||||
FastDivmod pq_divmod;
|
||||
FastDivmod q_divmod;
|
||||
FastDivmod c_divmod;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationIteratorOptimizedParams() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
):
|
||||
layout(layout),
|
||||
sc_divmod(problem_size.S * problem_size.C),
|
||||
pq_divmod(problem_size.P * problem_size.Q),
|
||||
q_divmod(problem_size.Q),
|
||||
c_divmod(problem_size.C) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationIteratorOptimizedParams(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Layout const &layout,
|
||||
int element_size_bits, ///< size of each element in bits
|
||||
MatrixCoord threadblock_shape,
|
||||
int thread_count,
|
||||
int access_size,
|
||||
layout::PitchLinearCoord threadmap_iterations,
|
||||
layout::PitchLinearCoord threadmap_delta
|
||||
):
|
||||
Conv2dWgradActivationIteratorOptimizedParams(
|
||||
problem_size,
|
||||
layout
|
||||
) {
|
||||
|
||||
TRACE_CONV_INITIALIZERS("conv2d_wgrad", "activation",
|
||||
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
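The structures in this header exist so that the per-access integer divisions of the analytic iterators can be replaced by host-precomputed FastDivmod objects and byte-offset increments. As a minimal editorial illustration (not part of the commit; plain / and % stand in for cutlass::FastDivmod), the sketch below shows the npq -> (n, p, q) decomposition that pq_divmod and q_divmod implement.

// Standalone sketch of the gemm-K index decomposition used by the optimized
// iterators. Editorial illustration only; P and Q are made-up extents.
#include <cassert>

struct NpqCoord { int n, p, q; };

NpqCoord decompose_npq(int offset_npq, int P, int Q) {
  NpqCoord c;
  c.n = offset_npq / (P * Q);           // pq_divmod: quotient
  int residual = offset_npq % (P * Q);  // pq_divmod: remainder
  c.p = residual / Q;                   // q_divmod: quotient
  c.q = residual % Q;                   // q_divmod: remainder
  return c;
}

int main() {
  int const P = 7, Q = 5;
  for (int npq = 0; npq < 3 * P * Q; ++npq) {
    NpqCoord c = decompose_npq(npq, P, Q);
    // Recomposing the coordinate must give back the linear gemm-K offset.
    assert((c.n * P + c.p) * Q + c.q == npq);
  }
  return 0;
}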
|
||||
|
||||
170
include/cutlass/conv/threadblock/conv2d_tile_iterator.h
Normal file
@ -0,0 +1,170 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template wraps the tile access iterator concept to load whole tiles from tensors in
|
||||
memory used for implicit GEMM convolution.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileAccessIterator_>
|
||||
class TileIterator {
|
||||
public:
|
||||
using TileAccessIterator = TileAccessIterator_;
|
||||
|
||||
using Shape = typename TileAccessIterator::Shape;
|
||||
using Element = typename TileAccessIterator::Element;
|
||||
using Layout = typename TileAccessIterator::Layout;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using ThreadMap = typename TileAccessIterator::ThreadMap;
|
||||
using AccessType = typename TileAccessIterator::AccessType;
|
||||
using TensorRef = typename TileAccessIterator::TensorRef;
|
||||
using Index = typename TileAccessIterator::Index;
|
||||
using LongIndex = typename TileAccessIterator::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm;
|
||||
static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport;
|
||||
using Params = typename TileAccessIterator::Params;
|
||||
static int const kConvDim = TileAccessIterator::kConvDim;
|
||||
using ConvProblemSize = typename TileAccessIterator::ConvProblemSize;
|
||||
|
||||
/// Fragment object to be loaded or stored
|
||||
using Fragment = cutlass::Array<
|
||||
Element,
|
||||
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
private:
|
||||
|
||||
/// Internal state
|
||||
TileAccessIterator tile_access_iterator_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIterator(
|
||||
Params const ¶ms,
|
||||
ConvProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(params, problem_size, ptr, thread_idx, threadblock_offset) { }
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
tile_access_iterator_.add_pointer_offset(pointer_offset);
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIterator &operator++() {
|
||||
tile_access_iterator_.advance();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances to the next tile in memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIterator operator++(int) {
|
||||
TileIterator self(*this);
|
||||
operator++();
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
||||
|
||||
frag.clear();
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[c + s * ThreadMap::Iterations::kContiguous],
|
||||
tile_access_iterator_.get() + pointer_offset,
|
||||
tile_access_iterator_.valid()
|
||||
);
|
||||
|
||||
++tile_access_iterator_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
tile_access_iterator_.set_iteration_index(0);
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance() {
|
||||
tile_access_iterator_.advance();
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(ConvProblemSize const &problem_size) {
|
||||
|
||||
// dispatch to iterator implementation
|
||||
return TileAccessIterator::can_implement(problem_size);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
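TileIterator is a thin wrapper over a tile-access iterator: load() resets the access index and then issues one predicated global load per (contiguous, strided) access, incrementing the wrapped iterator once per access. The host-side model below is an editorial sketch of that access order only; MockAccessIterator and the loop bounds are illustrative, not CUTLASS types.

// Host-side model of the access order produced by TileIterator::load().
// Editorial sketch; not part of the commit.
#include <cassert>
#include <vector>

struct MockAccessIterator {
  int index = 0;
  std::vector<int> visited;
  void set_iteration_index(int idx) { index = idx; }
  MockAccessIterator &operator++() { visited.push_back(index++); return *this; }
};

int main() {
  int const kContiguous = 2;  // ThreadMap::Iterations::kContiguous (illustrative)
  int const kStrided = 4;     // ThreadMap::Iterations::kStrided    (illustrative)

  MockAccessIterator it;
  it.set_iteration_index(0);

  // Same loop nest as load_with_pointer_offset(): strided outer, contiguous inner.
  for (int s = 0; s < kStrided; ++s) {
    for (int c = 0; c < kContiguous; ++c) {
      // global_load(frag_ptr[c + s * kContiguous], it.get(), it.valid());
      ++it;
    }
  }

  // One access per (contiguous, strided) pair, visited in linear order.
  assert(static_cast<int>(it.visited.size()) == kContiguous * kStrided);
  for (int i = 0; i < kContiguous * kStrided; ++i) {
    assert(it.visited[i] == i);
  }
  return 0;
}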
|
||||
|
||||
@ -0,0 +1,254 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv2dWgradActivationTileAccessIteratorAnalytic {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
using Params = Conv2dAnalyticParams<Layout>;
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
// Filter position (r, s, c) in the contiguous dimension stays constant for each gemm_iteration_k
|
||||
int filter_r_[ThreadMap::Iterations::kContiguous];
|
||||
int filter_s_[ThreadMap::Iterations::kContiguous];
|
||||
int filter_c_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
int offset_npq_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorAnalytic(
|
||||
Params const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr))
|
||||
{
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// initialize r,s,c filter position for every contiguous iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int rsc_offset = threadblock_offset.column() + thread_coord.contiguous()
|
||||
+ c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C);
|
||||
int residual = rsc_offset % (problem_size_.S * problem_size_.C);
|
||||
|
||||
filter_s_[c] = residual / problem_size_.C;
|
||||
filter_c_[c] = residual % problem_size_.C;
|
||||
}
|
||||
|
||||
// initialize n, p, q offset for every strided iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
offset_npq_[s] = threadblock_offset.row() + thread_coord.strided()
|
||||
+ s * ThreadMap::Delta::kStrided;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
// moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the activation tensor x that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int r = filter_r_[iteration_contiguous_];
|
||||
int s = filter_s_[iteration_contiguous_];
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
r = (problem_size_.R - 1 - r);
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q);
|
||||
int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q);
|
||||
|
||||
int p = residual / problem_size_.Q;
|
||||
int q = residual % problem_size_.Q;
|
||||
|
||||
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
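The analytic Wgrad activation iterator recomputes, for every access, the mapping from a (gemm-K, gemm-contiguous) index pair to an activation coordinate, including the filter-tap flip for true convolution. The scalar reference below is an editorial sketch of that arithmetic with made-up problem sizes, not part of this commit.

// Scalar reference for Conv2dWgradActivationTileAccessIteratorAnalytic::at().
// Editorial sketch; all problem sizes are illustrative.
#include <cassert>

struct Coord4 { int n, h, w, c; };

Coord4 wgrad_activation_coord(
    int offset_npq, int rsc_offset,
    int P, int Q, int R, int S, int C,
    int stride_h, int stride_w,
    int pad_h, int pad_w,
    int dilation_h, int dilation_w,
    bool is_convolution) {

  // Decompose the contiguous gemm index into filter taps and channel.
  int r = rsc_offset / (S * C);
  int residual = rsc_offset % (S * C);
  int s = residual / C;
  int c = residual % C;

  // True convolution flips the filter taps; cross-correlation does not.
  if (is_convolution) {
    r = R - 1 - r;
    s = S - 1 - s;
  }

  // Decompose the strided gemm index into (n, p, q).
  int n = offset_npq / (P * Q);
  residual = offset_npq % (P * Q);
  int p = residual / Q;
  int q = residual % Q;

  // Output coordinate (p, q) plus filter tap (r, s) gives the input coordinate.
  int h = p * stride_h - pad_h + r * dilation_h;
  int w = q * stride_w - pad_w + s * dilation_w;
  return Coord4{n, h, w, c};
}

int main() {
  // 3x3 filter, unit stride and dilation, padding 1, cross-correlation.
  Coord4 x = wgrad_activation_coord(
      /*offset_npq=*/0, /*rsc_offset=*/0,
      /*P=*/8, /*Q=*/8, /*R=*/3, /*S=*/3, /*C=*/4,
      /*stride_h=*/1, /*stride_w=*/1, /*pad_h=*/1, /*pad_w=*/1,
      /*dilation_h=*/1, /*dilation_w=*/1, /*is_convolution=*/false);

  // The first tap of the first output pixel lands in the padding region:
  // (n, h, w, c) = (0, -1, -1, 0), which valid() rejects.
  assert(x.n == 0 && x.h == -1 && x.w == -1 && x.c == 0);
  return 0;
}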
|
||||
|
||||
|
||||
@ -0,0 +1,273 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv2dWgradActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
using Params = Conv2dWgradActivationIteratorOptimizedParams;
|
||||
|
||||
private:
|
||||
|
||||
Conv2dWgradActivationIteratorOptimizedParams const ¶ms_;
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
// Precomputed effective filter position (r, s) in the contiguous dimension stays constant for each gemm_iteration_k
|
||||
// required for npq -> nhw translation
|
||||
int precomputed_filter_r_[ThreadMap::Iterations::kContiguous];
|
||||
int precomputed_filter_s_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
// Channel dimension in contiguous dimension stays constant for each gemm_iteration_k
|
||||
int filter_c_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
int offset_npq_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorOptimized(
|
||||
Conv2dWgradActivationIteratorOptimizedParams const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr))
|
||||
{
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// initialize r,s,c filter position for every contiguous iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int rsc_offset = threadblock_offset.column() + thread_coord.contiguous()
|
||||
+ c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
//
|
||||
// filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C);
|
||||
// int residual = rsc_offset % (problem_size_.S * problem_size_.C);
|
||||
//
|
||||
// filter_s_[c] = residual / problem_size_.C;
|
||||
// filter_c_[c] = residual % problem_size_.C;
|
||||
|
||||
int residual;
|
||||
params_.sc_divmod(precomputed_filter_r_[c], residual, rsc_offset);
|
||||
params_.c_divmod(precomputed_filter_s_[c], filter_c_[c], residual);
|
||||
|
||||
int r = precomputed_filter_r_[c];
|
||||
int s = precomputed_filter_s_[c];
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
r = (problem_size_.R - 1 - r);
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
}
|
||||
|
||||
// initialize n, p, q offset for every strided iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
offset_npq_[s] = threadblock_offset.row() + thread_coord.strided()
|
||||
+ s * ThreadMap::Delta::kStrided;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
// moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the activation tensor x that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
//
|
||||
// int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q);
|
||||
// int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q);
|
||||
//
|
||||
// int p = residual / problem_size_.Q;
|
||||
// int q = residual % problem_size_.Q;
|
||||
|
||||
int residual, n, p, q;
|
||||
|
||||
params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]);
|
||||
params_.q_divmod(p, q, residual);
|
||||
|
||||
int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];
|
||||
int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_];
|
||||
|
||||
return TensorCoord(n, h, w, filter_c_[iteration_contiguous_]);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv2dWgradActivationTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
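Compared to the analytic variant, the optimized Wgrad activation iterator folds padding and dilation into precomputed_filter_r_ / precomputed_filter_s_ at construction time, leaving a single multiply-add per dimension in at(). The small check below is an editorial sketch with illustrative extents, verifying that the folded form agrees with the analytic expression.

// Checks the folding used by Conv2dWgradActivationTileAccessIteratorOptimized:
//   precomputed_r = -pad_h + r * dilation_h   (after the optional filter flip)
//   h             = p * stride_h + precomputed_r
// matches the analytic form p * stride_h - pad_h + r * dilation_h.
// Editorial sketch; extents are illustrative.
#include <cassert>

int main() {
  int const R = 3, pad_h = 1, dilation_h = 2, stride_h = 2;

  for (int flip = 0; flip <= 1; ++flip) {
    for (int r = 0; r < R; ++r) {
      int r_eff = flip ? (R - 1 - r) : r;  // Mode::kConvolution flips the tap
      int precomputed_r = -pad_h + r_eff * dilation_h;
      for (int p = 0; p < 8; ++p) {
        int h_optimized = p * stride_h + precomputed_r;
        int h_analytic  = p * stride_h - pad_h + r_eff * dilation_h;
        assert(h_optimized == h_analytic);
      }
    }
  }
  return 0;
}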
|
||||
@ -0,0 +1,234 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
    matrix from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv2dWgradOutputGradientTileAccessIteratorAnalytic {
public:

  //
  // Types
  //
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  static_assert(sizeof_bits<Element>::value >= 8,
    "WGRAD requires elements of size 8b or greater.");

  //
  // Parameters structure
  //

  using Params = Conv2dAnalyticParams<Layout>;

private:

  Params const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  int filter_k_[ThreadMap::Iterations::kContiguous];

  int offset_npq_[ThreadMap::Iterations::kStrided];

public:

  CUTLASS_HOST_DEVICE
  Conv2dWgradOutputGradientTileAccessIteratorAnalytic(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    // initialize filter_k for every contiguous iteration
    CUTLASS_PRAGMA_UNROLL
    for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
      filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous()
        + c * ThreadMap::Delta::kContiguous;
    }

    // initialize n, p, q offset for every strided iteration
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_npq_[s] = threadblock_offset.column() + thread_coord.strided()
        + s * ThreadMap::Delta::kStrided;

    }
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_npq_[s] += Shape::kColumn * problem_size_.split_k_slices;
    }
  }

  /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {

    int npq = offset_npq_[iteration_strided_];

    int n = npq / (problem_size_.P * problem_size_.Q);
    int residual = npq % (problem_size_.P * problem_size_.Q);

    int p = residual / problem_size_.Q;
    int q = residual % problem_size_.Q;

    return TensorCoord(n, p, q, filter_k_[iteration_contiguous_]);
  }

  /// Returns true if the current coordinate is within the output gradient tensor Dy
  CUTLASS_HOST_DEVICE
  bool valid() const {
    TensorCoord coord = at();

    return coord.n() < problem_size_.N &&
      coord.h() < problem_size_.P &&
      coord.w() < problem_size_.Q &&
      coord.c() < problem_size_.K;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;
    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.K % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

};
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
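The at() method above recovers (n, p, q) from the linear GEMM-K offset by two divisions and two remainders. A standalone sketch of that decomposition, with made-up output extents (P and Q are illustrative, not from this commit):

#include <cassert>
#include <cstdio>

struct Npq { int n, p, q; };

// Mirrors: n = npq / (P*Q); residual = npq % (P*Q); p = residual / Q; q = residual % Q;
Npq decompose_npq(int npq, int P, int Q) {
  int n = npq / (P * Q);
  int residual = npq % (P * Q);
  return { n, residual / Q, residual % Q };
}

int main() {
  int P = 7, Q = 5;                              // hypothetical output extents
  Npq c = decompose_npq(3 * P * Q + 2 * Q + 4, P, Q);
  assert(c.n == 3 && c.p == 2 && c.q == 4);
  printf("n=%d p=%d q=%d\n", c.n, c.p, c.q);
  return 0;
}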
@ -0,0 +1,300 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
    matrix from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv2dWgradOutputGradientTileAccessIteratorOptimized {
public:

  //
  // Types
  //
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  static_assert(sizeof_bits<Element>::value >= 8,
    "WGRAD requires elements of size 8b or greater.");

  //
  // Parameters structure
  //

  struct Params : Conv2dWgradOutputGradientIteratorOptimizedParams {

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Conv2dWgradOutputGradientIteratorOptimizedParams const &base):
      Conv2dWgradOutputGradientIteratorOptimizedParams(base) { }

    CUTLASS_HOST_DEVICE
    Params(
      Conv2dProblemSize const &problem_size,
      Layout const &layout
    ):
      Conv2dWgradOutputGradientIteratorOptimizedParams(
        problem_size,
        layout,
        sizeof_bits<Element>::value,
        {Shape::kRow, Shape::kColumn},
        ThreadMap::kThreads,
        ThreadMap::kElementsPerAccess,
        {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
        {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
      ) { }
  };

private:

  Conv2dWgradOutputGradientIteratorOptimizedParams const &params_;
  Conv2dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  uint32_t predicates_;
  int filter_k_;
  int offset_npq_;

public:

  CUTLASS_HOST_DEVICE
  Conv2dWgradOutputGradientTileAccessIteratorOptimized(
    Conv2dWgradOutputGradientIteratorOptimizedParams const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    predicates_(0),
    filter_k_(0),
    offset_npq_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_k_ = threadblock_offset.row() + thread_coord.contiguous();
    offset_npq_ = threadblock_offset.column() + thread_coord.strided();

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {

        int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous;
        int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided;

        bool predicate = valid_(at_(offset_npq, filter_k));

        uint32_t pred = (predicate ? 1u : 0);

        int pred_idx = c + s * ThreadMap::Iterations::kContiguous;

        predicates_ |= (pred << pred_idx);
      }
    }

    // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0)
    pointer_ += (
      offset_npq_ * params.layout.stride()[0] + filter_k_
    ) * sizeof_bits<Element>::value / 8;

    set_iteration_index(0);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile
    offset_npq_ += Shape::kColumn * problem_size_.split_k_slices;

    // Clear predicates if needed
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) {
        uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
        predicates_ = (predicates_ & (~kClearMask));
      }
    }

    pointer_ += params_.inc_next_npq;
  }

private:
  /// Returns the coordinate in the output gradient tensor Dy that is pointed to
  /// by offset_npq and k.
  CUTLASS_HOST_DEVICE
  TensorCoord at_(int offset_npq, int k) const {

    // The subsequent fast_divmod() operations are equivalent to the following logical computation:
    //
    //
    // int npq = offset_npq;
    // int n = npq / (problem_size_.P * problem_size_.Q);
    // int residual = npq % (problem_size_.P * problem_size_.Q);
    //
    // int p = residual / problem_size_.Q;
    // int q = residual % problem_size_.Q;

    int residual, n, p, q;

    params_.pq_divmod(n, residual, offset_npq);
    params_.q_divmod(p, q, residual);

    return TensorCoord(n, p, q, k);
  }

  /// Returns true if the coord is within the output gradient tensor Dy
  CUTLASS_HOST_DEVICE
  bool valid_(TensorCoord coord) const {

    return coord.n() < problem_size_.N &&
      coord.c() < problem_size_.K;
  }

public:

  /// Returns true if the current coordinate is within the output gradient tensor Dy
  CUTLASS_HOST_DEVICE
  bool valid() const {

    LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
    return (predicates_ & (1u << pred_idx));
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    return reinterpret_cast<AccessType const *>(
      pointer_ +
      iteration_strided_ * params_.offset_next_strided +
      iteration_contiguous_ * params_.offset_next_contiguous
    );
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;
    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.K % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

};
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
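The optimized iterator above packs one validity bit per (contiguous, strided) access into a 32-bit word and clears whole strided rows in advance() once they run past NPQ. A standalone sketch of that bit-packing; the iteration counts are made up, whereas CUTLASS derives them from the ThreadMap:

#include <cstdint>
#include <cstdio>

int const kContiguous = 2;  // hypothetical ThreadMap::Iterations::kContiguous
int const kStrided    = 4;  // hypothetical ThreadMap::Iterations::kStrided

int main() {
  uint32_t predicates = 0;

  // Build the mask: bit (c + s * kContiguous) records whether access (c, s) is in bounds.
  for (int s = 0; s < kStrided; ++s) {
    for (int c = 0; c < kContiguous; ++c) {
      bool valid = (s < 3);                       // pretend the last strided row is out of range
      predicates |= (valid ? 1u : 0u) << (c + s * kContiguous);
    }
  }

  // advance()-style clearing: drop every contiguous bit of a strided row that ran past NPQ.
  int out_of_range_row = 3;
  uint32_t clear_mask = ((1u << kContiguous) - 1) << (out_of_range_row * kContiguous);
  predicates &= ~clear_mask;

  printf("predicates = 0x%02x\n", predicates);    // 0x3f: rows 0..2 valid, row 3 cleared
  return 0;
}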
@ -0,0 +1,263 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
    matrix from memory.

    This iterator assumes TensorNDHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv3dDgradFilterTileAccessIteratorAnalytic {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNDHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 3;
  using ConvProblemSize = typename conv::Conv3dProblemSize;

  static_assert(sizeof_bits<Element>::value >= 8,
    "DGRAD requires elements of size 8b or larger.");

  //
  // Parameters structure
  //

  struct Params {

    Layout layout;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      Conv3dProblemSize const &problem_size,
      Layout const &layout
    ): layout(layout) {

    }
  };

private:

  Params const &params_;
  Conv3dProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  // For a fixed filter position (t, r, s), find and fill offset_k_ and offset_c_ in the strided and contiguous dimensions
  int filter_t_;
  int filter_r_;
  int filter_s_;
  int offset_k_[ThreadMap::Iterations::kStrided];
  int offset_c_[ThreadMap::Iterations::kContiguous];

public:

  CUTLASS_HOST_DEVICE
  Conv3dDgradFilterTileAccessIteratorAnalytic(
    Params const &params,
    Conv3dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_t_(0),
    filter_r_(0),
    filter_s_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    CUTLASS_PRAGMA_UNROLL
    for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
      offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous()
        + c * ThreadMap::Delta::kContiguous;
    }

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_k_[s] =
        threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
    }
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next tile
    ++filter_s_;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    filter_s_ = 0;
    ++filter_r_;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    filter_r_ = 0;
    ++filter_t_;
    if (filter_t_ < problem_size_.T) {
      return;
    }
    filter_t_ = 0;

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_k_[s] += Shape::kRow * problem_size_.split_k_slices;
    }
  }

  /// Returns the coordinate in the filter tensor w that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {

    int c = offset_c_[iteration_contiguous_];
    int k = offset_k_[iteration_strided_];

    return TensorCoord(k, filter_t_, filter_r_, filter_s_, c);
  }

  /// Returns true if the current coordinate is within the filter tensor w
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord coord = at();

    return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);

  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv3dDgradFilterTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;
    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv3dProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.C % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
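The advance() method above walks the 3-D filter position with s varying fastest, then r, then t, and only bumps the GEMM-K offsets once the whole (t, r, s) volume has been covered. A standalone sketch of that traversal order; the filter extents below are made up for illustration:

#include <cstdio>

int main() {
  int T = 2, R = 2, S = 3;           // hypothetical filter extents
  int t = 0, r = 0, s = 0;

  for (int step = 0; step < T * R * S; ++step) {
    printf("(t=%d, r=%d, s=%d)\n", t, r, s);

    // Same rollover logic as advance(): ++s, then carry into r, then into t.
    if (++s == S) { s = 0;
      if (++r == R) { r = 0;
        if (++t == T) { t = 0; /* next: bump the GEMM-K offsets by the tile size */ }
      }
    }
  }
  return 0;
}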
@ -0,0 +1,331 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
    matrix from memory.

    This iterator assumes TensorNDHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_,
  conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided
>
class Conv3dDgradOutputGradientTileAccessIteratorAnalytic;
/////////////////////////////////////////////////////////////////////////////////////////////////

// Conv3dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using
// unscaled coordinates
template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv3dDgradOutputGradientTileAccessIteratorAnalytic <
  Shape_,
  Element_,
  ThreadMap_,
  conv::StrideSupport::kStrided
> {
public:

  //
  // Types
  //
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNDHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 3;
  using ConvProblemSize = typename conv::Conv3dProblemSize;

  static_assert(sizeof_bits<Element>::value >= 8,
    "DGRAD requires elements of size 8b or greater.");

  //
  // Simplifying assertions
  //

  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  struct Params {

    Layout layout;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      ConvProblemSize const &problem_size,
      Layout const &layout
    ): layout(layout) {

    }
  };

private:

  Params const &params_;
  ConvProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  int filter_k_;
  int filter_t_;
  int filter_r_;
  int filter_s_;

  int offset_n_[ThreadMap::Iterations::kStrided];
  int offset_d_[ThreadMap::Iterations::kStrided];
  int offset_w_[ThreadMap::Iterations::kStrided];
  int offset_h_[ThreadMap::Iterations::kStrided];

private:

  /// Returns the coordinate in the output tensor Dy that is currently pointed to
  /// by the iterator but DOES NOT scale by the convolution stride. This is needed
  /// to compute predicates in the valid() method. The return value of the public at()
  /// method is correctly scaled.
  CUTLASS_HOST_DEVICE
  TensorCoord unscaled_at_() const {
    int n = offset_n_[iteration_strided_];
    int d = offset_d_[iteration_strided_];
    int h = offset_h_[iteration_strided_];
    int w = offset_w_[iteration_strided_];

    int t = filter_t_;
    int r = filter_r_;
    int s = filter_s_;

    if (problem_size_.mode == Mode::kConvolution) {
      t = (problem_size_.T - 1 - t);
      r = (problem_size_.R - 1 - r);
      s = (problem_size_.S - 1 - s);
    }

    int z = (d + problem_size_.pad_d - t * problem_size_.dilation_d);
    int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h);
    int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w);

    return TensorCoord(n, z, p, q, filter_k_);
  }

public:

  CUTLASS_HOST_DEVICE
  Conv3dDgradOutputGradientTileAccessIteratorAnalytic(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_k_(0),
    filter_t_(0),
    filter_r_(0),
    filter_s_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_k_ = threadblock_offset.column() + thread_coord.contiguous();

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;

      offset_n_[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W);
      int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W);

      offset_d_[s] = residual / (problem_size_.H * problem_size_.W);
      residual = residual % (problem_size_.H * problem_size_.W);

      offset_h_[s] = residual / problem_size_.W;
      offset_w_[s] = residual % problem_size_.W;
    }
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // move to the next tile
    ++filter_s_;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    filter_s_ = 0;
    ++filter_r_;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    filter_r_ = 0;
    ++filter_t_;
    if (filter_t_ < problem_size_.T) {
      return;
    }
    filter_t_ = 0;

    filter_k_ += Shape_::kColumn * problem_size_.split_k_slices;
  }

  /// Returns the coordinate in the output tensor Dy that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {

    TensorCoord coord = unscaled_at_();

    return TensorCoord(
      coord.n(),
      coord.d() / problem_size_.stride_d,
      coord.h() / problem_size_.stride_h,
      coord.w() / problem_size_.stride_w,
      coord.c());
  }

  /// Returns true if the current coordinate is within the output tensor Dy
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord unscaled_coord = unscaled_at_();
    TensorCoord coord = at();

    return
      !(unscaled_coord.d() % problem_size_.stride_d) &&
      !(unscaled_coord.h() % problem_size_.stride_h) &&
      !(unscaled_coord.w() % problem_size_.stride_w) &&
      coord.n() < problem_size_.N &&
      coord.d() >= 0 && coord.d() < problem_size_.Z &&
      coord.h() >= 0 && coord.h() < problem_size_.P &&
      coord.w() >= 0 && coord.w() < problem_size_.Q &&
      coord.c() < problem_size_.K;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv3dDgradOutputGradientTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;
    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(ConvProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.K % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
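The strided-Dgrad valid() above predicates off any access whose unscaled coordinate is not a multiple of the convolution stride, because only those positions map back to a real output location. A standalone sketch of that divisibility test for one dimension; the stride, padding, dilation, and filter offset are made up for illustration:

#include <cstdio>

int main() {
  int stride_h = 2;
  int pad_h = 1, dilation_h = 1, r = 1;

  for (int h = 0; h < 6; ++h) {
    int unscaled_p = h + pad_h - r * dilation_h;            // same formula as unscaled_at_()
    bool contributes = unscaled_p >= 0 && (unscaled_p % stride_h) == 0;
    int p = contributes ? unscaled_p / stride_h : -1;       // at() divides by the stride
    printf("h=%d -> unscaled_p=%d, contributes=%d, p=%d\n", h, unscaled_p, contributes, p);
  }
  return 0;
}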
@ -0,0 +1,296 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
    matrix from memory.

    This iterator assumes TensorNDHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv3dFpropActivationTileAccessIteratorAnalytic {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNDHWC;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 3;
  using ConvProblemSize = typename conv::Conv3dProblemSize;

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  struct Params {

    Layout layout;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      ConvProblemSize const &problem_size,
      Layout const &layout
    ): layout(layout) {

    }
  };

private:

  Params const &params_;
  ConvProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  int filter_t_;
  int filter_r_;
  int filter_s_;
  int filter_c_;

  int offset_n_[ThreadMap::Iterations::kStrided];
  int offset_z_[ThreadMap::Iterations::kStrided];
  int offset_p_[ThreadMap::Iterations::kStrided];
  int offset_q_[ThreadMap::Iterations::kStrided];

public:

  CUTLASS_HOST_DEVICE
  Conv3dFpropActivationTileAccessIteratorAnalytic(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_t_(0),
    filter_r_(0),
    filter_s_(0),
    filter_c_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_c_ = threadblock_offset.column() + thread_coord.contiguous();

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;

      offset_n_[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q);
      int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q);

      offset_z_[s] = residual / (problem_size_.P * problem_size_.Q);
      residual = residual % (problem_size_.P * problem_size_.Q);

      offset_p_[s] = residual / problem_size_.Q;
      offset_q_[s] = residual % problem_size_.Q;
    }

    set_iteration_index(0);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next tile
    ++filter_s_;
    if (filter_s_ < problem_size_.S) {
      return;
    }
    filter_s_ = 0;
    ++filter_r_;
    if (filter_r_ < problem_size_.R) {
      return;
    }
    filter_r_ = 0;
    ++filter_t_;
    if (filter_t_ < problem_size_.T) {
      return;
    }
    filter_t_ = 0;

    filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
  }

  /// Returns the coordinate in the activations tensor X that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {
    int n = offset_n_[iteration_strided_];
    int z = offset_z_[iteration_strided_];
    int p = offset_p_[iteration_strided_];
    int q = offset_q_[iteration_strided_];

    int t = filter_t_;
    int r = filter_r_;
    int s = filter_s_;

    if (problem_size_.mode == Mode::kConvolution) {
      t = (problem_size_.T - 1 - filter_t_);
      r = (problem_size_.R - 1 - filter_r_);
      s = (problem_size_.S - 1 - filter_s_);
    }

    int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d;
    int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
    int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;

    return TensorCoord(n, d, h, w, filter_c_);
  }

  /// Returns true if the current coordinate is within the activations tensor X
  CUTLASS_HOST_DEVICE
  bool valid() const {

    TensorCoord coord = at();

    return coord.n() < problem_size_.N &&
      coord.d() >= 0 && coord.d() < problem_size_.D &&
      coord.h() >= 0 && coord.h() < problem_size_.H &&
      coord.w() >= 0 && coord.w() < problem_size_.W &&
      coord.c() < problem_size_.C;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {

    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    AccessType const *ptr = reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);

    return ptr;
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  Conv3dFpropActivationTileAccessIteratorAnalytic &operator++() {
    ++iteration_contiguous_;
    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
      return *this;
    }
    iteration_contiguous_ = 0;

    ++iteration_strided_;
    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
      return *this;
    }
    iteration_strided_ = 0;

    return *this;
  }

  /// Determines whether the Implicit GEMM can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(ConvProblemSize const &problem_size) {

    // check alignment constraint on iterator's contiguous dimension
    if (problem_size.C % (128/sizeof_bits<Element>::value)) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
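The at() and valid() pair above map an output position and filter offset back to an input location via d = z*stride - pad + t*dilation (and likewise for h and w), then predicate off anything that falls outside the input extent. A standalone sketch of that mapping for the H dimension only; the stride, padding, dilation, and extents are made up for illustration:

#include <cstdio>

int main() {
  int stride_h = 2, pad_h = 1, dilation_h = 1;
  int H = 8;                                         // hypothetical input height

  for (int p = 0; p < 4; ++p) {                      // output row
    for (int r = 0; r < 3; ++r) {                    // filter row
      int h = p * stride_h - pad_h + r * dilation_h; // same formula as at()
      bool in_bounds = (h >= 0 && h < H);            // same test as valid()
      printf("p=%d r=%d -> h=%d (%s)\n", p, r, h, in_bounds ? "load" : "skip");
    }
  }
  return 0;
}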
@ -0,0 +1,262 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
    matrix from memory.

    This iterator assumes TensorNDHWC layout of tensors in Global Memory.

    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv3d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,
  typename Element_,
  typename ThreadMap_
>
class Conv3dFpropFilterTileAccessIteratorAnalytic {
public:

  //
  // Types
  //

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::TensorNDHWC;
  using ThreadMap = ThreadMap_;
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 3;
  using ConvProblemSize = typename conv::Conv3dProblemSize;

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1,
    "Require Iterations::kContiguous == 1");

  //
  // Parameters structure
  //

  struct Params {

    Layout layout;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      ConvProblemSize const &problem_size,
      Layout const &layout
    ): layout(layout) {

    }
  };

private:

  Params const &params_;
  ConvProblemSize const &problem_size_;
  LongIndex iteration_contiguous_;
  LongIndex iteration_strided_;
  char const *pointer_;

  int filter_t_;
  int filter_r_;
  int filter_s_;
  int filter_c_;

  int offset_k_[ThreadMap::Iterations::kStrided];

public:

  CUTLASS_HOST_DEVICE
  Conv3dFpropFilterTileAccessIteratorAnalytic(
    Params const &params,
    ConvProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  ):
    params_(params),
    problem_size_(problem_size),
    pointer_(reinterpret_cast<char const *>(ptr)),
    filter_t_(0),
    filter_r_(0),
    filter_s_(0),
    filter_c_(0) {

    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);

    filter_c_ = threadblock_offset.row() + thread_coord.contiguous();

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
    }

    set_iteration_index(0);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
    iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// moves to the next tile
|
||||
++filter_s_;
|
||||
if (filter_s_ < problem_size_.S) {
|
||||
return;
|
||||
}
|
||||
filter_s_ = 0;
|
||||
|
||||
++filter_r_;
|
||||
if (filter_r_ < problem_size_.R) {
|
||||
return;
|
||||
}
|
||||
filter_r_ = 0;
|
||||
|
||||
++filter_t_;
|
||||
if (filter_t_ < problem_size_.T) {
|
||||
return;
|
||||
}
|
||||
filter_t_ = 0;
|
||||
|
||||
filter_c_ += Shape::kRow * problem_size_.split_k_slices;
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the filter tensor W that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int k = offset_k_[iteration_strided_];
|
||||
|
||||
return TensorCoord(k, filter_t_, filter_r_, filter_s_, filter_c_);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the filter tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K &&
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dFpropFilterTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(ConvProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
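As an aid to reading `Conv3dFpropFilterTileAccessIteratorAnalytic`, the following is a minimal host-side sketch, not part of the commit and with illustrative names, of the traversal order that `advance()` implements: the filter position steps through S fastest, then R, then T, and finally advances the channel offset by one threadblock K-tile scaled by `split_k_slices`.

```cpp
// Minimal host-side sketch (illustrative only) of the s -> r -> t -> c traversal
// order implemented by Conv3dFpropFilterTileAccessIteratorAnalytic::advance().
#include <cstdio>

struct FilterPos { int t, r, s, c; };

// Mirrors advance(): increment S fastest, then R, then T, then step C by a tile.
void advance(FilterPos &p, int T, int R, int S, int tile_c, int split_k_slices) {
  if (++p.s < S) return;
  p.s = 0;
  if (++p.r < R) return;
  p.r = 0;
  if (++p.t < T) return;
  p.t = 0;
  p.c += tile_c * split_k_slices;
}

int main() {
  FilterPos p{0, 0, 0, 0};
  // Toy problem: T = R = S = 2, threadblock K-tile of 32 channels, no split-K.
  for (int i = 0; i < 9; ++i) {
    std::printf("t=%d r=%d s=%d c=%d\n", p.t, p.r, p.s, p.c);
    advance(p, /*T=*/2, /*R=*/2, /*S=*/2, /*tile_c=*/32, /*split_k_slices=*/1);
  }
  return 0;
}
```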
|
||||
|
||||
|
||||
@ -0,0 +1,281 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNDHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv3dWgradActivationTileAccessIteratorAnalytic {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNDHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
Layout layout;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
): layout(layout) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv3dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
// Filter position (t,r,s,c) in contiguous dimension stays constant for each gemm_iteration_k
|
||||
int filter_t_[ThreadMap::Iterations::kContiguous];
|
||||
int filter_r_[ThreadMap::Iterations::kContiguous];
|
||||
int filter_s_[ThreadMap::Iterations::kContiguous];
|
||||
int filter_c_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
int offset_nzpq_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradActivationTileAccessIteratorAnalytic(
|
||||
Params const ¶ms,
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// initialize t,r,s,c filter position for every contiguous iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int trsc_offset = threadblock_offset.column() + thread_coord.contiguous()
|
||||
+ c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C);
|
||||
int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C);
|
||||
|
||||
filter_r_[c] = residual / (problem_size_.S * problem_size_.C);
|
||||
residual = residual % (problem_size_.S * problem_size_.C);
|
||||
|
||||
filter_s_[c] = residual / problem_size_.C;
|
||||
filter_c_[c] = residual % problem_size_.C;
|
||||
|
||||
}
|
||||
|
||||
// initialize n, z, p, q offset for every strided iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided()
|
||||
+ s * ThreadMap::Delta::kStrided;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
// moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the activation tensor x that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int t = filter_t_[iteration_contiguous_];
|
||||
int r = filter_r_[iteration_contiguous_];
|
||||
int s = filter_s_[iteration_contiguous_];
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
t = (problem_size_.T - 1 - t);
|
||||
r = (problem_size_.R - 1 - r);
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
|
||||
int z = residual / (problem_size_.P * problem_size_.Q);
|
||||
residual = residual % (problem_size_.P * problem_size_.Q);
|
||||
|
||||
int p = residual / problem_size_.Q;
|
||||
int q = residual % problem_size_.Q;
|
||||
|
||||
int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d;
|
||||
int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.d() >= 0 && coord.d() < problem_size_.D &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradActivationTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv3dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
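The analytic Wgrad activation iterator above recovers a five-dimensional activation coordinate from two flat indices: the GEMM-B column index decomposes into (t, r, s, c), the GEMM-K index into (n, z, p, q), and the two combine through stride, pad, and dilation into (n, d, h, w, c). A minimal standalone sketch of that index math, using plain integer division and omitting the `kConvolution` filter-mirroring case, is shown below; the struct and function names are illustrative, not CUTLASS API.

```cpp
// Illustrative host-side sketch (not part of the commit) of the index math in
// Conv3dWgradActivationTileAccessIteratorAnalytic: a flat GEMM-B column index
// decomposes into (t, r, s, c), a flat GEMM-K index into (n, z, p, q), and the
// two combine into an activation coordinate (n, d, h, w, c).
#include <cstdio>

struct Problem {
  int Z, P, Q, R, S, C;
  int stride_d, stride_h, stride_w;
  int pad_d, pad_h, pad_w;
  int dilation_d, dilation_h, dilation_w;
};

void decompose(Problem const &ps, int trsc, int nzpq) {
  // (t, r, s, c) from the contiguous (GEMM-B column) index
  int t = trsc / (ps.R * ps.S * ps.C);
  int residual = trsc % (ps.R * ps.S * ps.C);
  int r = residual / (ps.S * ps.C);
  residual = residual % (ps.S * ps.C);
  int s = residual / ps.C;
  int c = residual % ps.C;

  // (n, z, p, q) from the strided (GEMM-K) index
  int n = nzpq / (ps.Z * ps.P * ps.Q);
  residual = nzpq % (ps.Z * ps.P * ps.Q);
  int z = residual / (ps.P * ps.Q);
  residual = residual % (ps.P * ps.Q);
  int p = residual / ps.Q;
  int q = residual % ps.Q;

  // Activation coordinate addressed by this (output position, filter position) pair
  int d = z * ps.stride_d - ps.pad_d + t * ps.dilation_d;
  int h = p * ps.stride_h - ps.pad_h + r * ps.dilation_h;
  int w = q * ps.stride_w - ps.pad_w + s * ps.dilation_w;

  std::printf("n=%d d=%d h=%d w=%d c=%d\n", n, d, h, w, c);
}

int main() {
  Problem ps{4, 4, 4, 3, 3, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  decompose(ps, /*trsc=*/75, /*nzpq=*/21);   // prints n=0 d=1 h=0 w=0 c=3
  return 0;
}
```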
|
||||
|
||||
|
||||
@ -0,0 +1,346 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNDHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv3dWgradActivationTileAccessIteratorOptimized {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNDHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
Layout layout;
|
||||
|
||||
int RSC; // product of R*S*C
|
||||
unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC
|
||||
unsigned rsc_shr; // in device code.
|
||||
|
||||
int SC; // product of S*C
|
||||
unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC
|
||||
unsigned sc_shr; // in device code.
|
||||
|
||||
unsigned c_mul; // precomputed quantities for fast computation of div/% by C
|
||||
unsigned c_shr; // in device code.
|
||||
|
||||
int ZPQ; // product of Z*P*Q
|
||||
unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ
|
||||
unsigned zpq_shr; // in device code.
|
||||
|
||||
int PQ; // product of P*Q
|
||||
unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ
|
||||
unsigned pq_shr; // in device code.
|
||||
|
||||
unsigned q_mul; // precomputed quantities for fast computation of div/% by Q
|
||||
unsigned q_shr; // in device code.
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
): layout(layout) {
|
||||
|
||||
// Precompute several quantities for fast modulo arithmetic.
|
||||
RSC = problem_size.R * problem_size.S * problem_size.C;
|
||||
find_divisor(rsc_mul, rsc_shr, RSC);
|
||||
|
||||
SC = problem_size.S * problem_size.C;
|
||||
find_divisor(sc_mul, sc_shr, SC);
|
||||
|
||||
find_divisor(c_mul, c_shr, problem_size.C);
|
||||
|
||||
ZPQ = problem_size.Z * problem_size.P * problem_size.Q;
|
||||
find_divisor(zpq_mul, zpq_shr, ZPQ);
|
||||
|
||||
PQ = problem_size.P * problem_size.Q;
|
||||
find_divisor(pq_mul, pq_shr, PQ);
|
||||
|
||||
find_divisor(q_mul, q_shr, problem_size.Q);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv3dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
// Precomputed effective filter position (t,r,s) in contiguous dimension stays constant for each gemm_iteration_k
|
||||
// required for nzpq -> ndhw translation
|
||||
int precomputed_filter_t_[ThreadMap::Iterations::kContiguous];
|
||||
int precomputed_filter_r_[ThreadMap::Iterations::kContiguous];
|
||||
int precomputed_filter_s_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
// Channel dimension in contiguous dimension stays constant for each gemm_iteration_k
|
||||
int filter_c_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
int offset_nzpq_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradActivationTileAccessIteratorOptimized(
|
||||
Params const ¶ms,
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// initialize t,r,s,c filter position for every contiguous iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int trsc_offset = threadblock_offset.column() + thread_coord.contiguous()
|
||||
+ c * ThreadMap::Delta::kContiguous;
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
//
|
||||
// filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C);
|
||||
// int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C);
|
||||
//
|
||||
// filter_r_[c] = residual / (problem_size_.S * problem_size_.C);
|
||||
// residual = residual % (problem_size_.S * problem_size_.C);
|
||||
//
|
||||
// filter_s_[c] = residual / problem_size_.C;
|
||||
// filter_c_[c] = residual % problem_size_.C;
|
||||
|
||||
int residual;
|
||||
fast_divmod(precomputed_filter_t_[c], residual, trsc_offset, params_.RSC, params_.rsc_mul, params_.rsc_shr);
|
||||
fast_divmod(precomputed_filter_r_[c], residual, residual, params_.SC, params_.sc_mul, params_.sc_shr);
|
||||
fast_divmod(precomputed_filter_s_[c], filter_c_[c], residual, problem_size_.C, params_.c_mul, params_.c_shr);
|
||||
|
||||
int t = precomputed_filter_t_[c];
|
||||
int r = precomputed_filter_r_[c];
|
||||
int s = precomputed_filter_s_[c];
|
||||
|
||||
if (problem_size_.mode == Mode::kConvolution) {
|
||||
t = (problem_size_.T - 1 - t);
|
||||
r = (problem_size_.R - 1 - r);
|
||||
s = (problem_size_.S - 1 - s);
|
||||
}
|
||||
|
||||
// effective t,r,s for every contiguous dimension
|
||||
precomputed_filter_t_[c] = - problem_size_.pad_d + t * problem_size_.dilation_d;
|
||||
precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h;
|
||||
precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w;
|
||||
|
||||
|
||||
}
|
||||
|
||||
// initialize n, z, p, q offset for every strided iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
|
||||
offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided()
|
||||
+ s * ThreadMap::Delta::kStrided;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
|
||||
// moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the activation tensor x that is currently pointed to
|
||||
/// by the iterator.
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
//
|
||||
// int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
// int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
//
|
||||
// int z = residual / (problem_size_.P * problem_size_.Q);
|
||||
// residual = residual % (problem_size_.P * problem_size_.Q);
|
||||
//
|
||||
// int p = residual / problem_size_.Q;
|
||||
// int q = residual % problem_size_.Q;
|
||||
|
||||
int residual, n, z, p, q;
|
||||
fast_divmod(n, residual, offset_nzpq_[iteration_strided_], params_.ZPQ, params_.zpq_mul, params_.zpq_shr);
|
||||
fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr);
|
||||
fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr);
|
||||
|
||||
int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_];
|
||||
int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];
|
||||
int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_];
|
||||
|
||||
return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]);
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activation tensor x
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.d() >= 0 && coord.d() < problem_size_.D &&
|
||||
coord.h() >= 0 && coord.h() < problem_size_.H &&
|
||||
coord.w() >= 0 && coord.w() < problem_size_.W &&
|
||||
coord.c() < problem_size_.C;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradActivationTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv3dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
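The optimized variant replaces the integer divisions of the analytic iterator with `find_divisor()`/`fast_divmod()` calls whose magic multiplier and shift are precomputed in `Params`. The sketch below illustrates the general multiply-shift idea under simplified assumptions; it is not CUTLASS's exact implementation, and the helper names are illustrative.

```cpp
// Hedged sketch of division by a precomputed "magic" multiplier and shift; the
// real find_divisor()/fast_divmod() in CUTLASS differ in detail, and the names
// here are illustrative. Exact for 0 <= x < 2^31, which comfortably covers the
// tensor extents involved (the 64-bit product x * magic cannot overflow there).
#include <cassert>
#include <cstdint>
#include <cstdio>

struct FastDivisor {
  uint32_t d;
  uint64_t magic;
  unsigned shift;
};

// Precompute once on the host, analogous in spirit to find_divisor().
FastDivisor make_divisor(uint32_t d) {
  assert(d != 0 && d < (1u << 31));
  unsigned shift = 0;
  while ((1ull << shift) < d) ++shift;              // shift = ceil(log2(d))
  shift += 32;
  uint64_t magic = ((1ull << shift) + d - 1) / d;   // ceil(2^shift / d)
  return FastDivisor{d, magic, shift};
}

// Use many times in the inner loop, analogous in spirit to fast_divmod():
// one multiply, one shift, one multiply-subtract instead of a division and a modulo.
void fast_divmod(uint32_t &quo, uint32_t &rem, uint32_t x, FastDivisor const &dv) {
  quo = static_cast<uint32_t>((x * dv.magic) >> dv.shift);
  rem = x - quo * dv.d;
}

int main() {
  FastDivisor dv = make_divisor(56);   // e.g. P*Q = 7*8
  for (uint32_t x : {0u, 55u, 56u, 12345u, 1000000u}) {
    uint32_t q, r;
    fast_divmod(q, r, x, dv);
    std::printf("%u / 56 = %u rem %u\n", x, q, r);
    assert(q == x / 56 && r == x % 56);
  }
  return 0;
}
```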
|
||||
|
||||
|
||||
@ -0,0 +1,256 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNDHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv3dWgradOutputGradientTileAccessIteratorAnalytic {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNDHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
Layout layout;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
): layout(layout) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv3dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_[ThreadMap::Iterations::kContiguous];
|
||||
|
||||
int offset_nzpq_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradOutputGradientTileAccessIteratorAnalytic(
|
||||
Params const ¶ms,
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)) {
|
||||
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// initialize filter_k for every contiguous iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous()
|
||||
+ c * ThreadMap::Delta::kContiguous;
|
||||
}
|
||||
|
||||
// initialize n, z, p, q offset for every strided iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_nzpq_[s] = threadblock_offset.column() + thread_coord.strided()
|
||||
+ s * ThreadMap::Delta::kStrided;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_nzpq_[s] += Shape::kColumn * problem_size_.split_k_slices;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the output gradient tensor Dy that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int nzpq = offset_nzpq_[iteration_strided_];
|
||||
|
||||
int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
|
||||
int z = residual / (problem_size_.P * problem_size_.Q);
|
||||
residual = residual % (problem_size_.P * problem_size_.Q);
|
||||
|
||||
int p = residual / problem_size_.Q;
|
||||
int q = residual % problem_size_.Q;
|
||||
|
||||
return TensorCoord(n, z, p, q, filter_k_[iteration_contiguous_]);
|
||||
}
|
||||
|
||||
|
||||
/// Returns true if the current coordinate is within the output gradient tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.d() < problem_size_.Z &&
|
||||
coord.h() < problem_size_.P &&
|
||||
coord.w() < problem_size_.Q &&
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
LongIndex offset = params_.layout(coord);
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradOutputGradientTileAccessIteratorAnalytic &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv3dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
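The `can_implement()` checks in these iterators all reject problems whose contiguous extent is not a multiple of `128 / sizeof_bits<Element>`, because each global-memory access is issued as a full 128-bit vector. A small hedged example of that arithmetic follows; the helper names and stand-in for `sizeof_bits` are illustrative, not part of the commit.

```cpp
// Hedged illustration of the alignment rule enforced by can_implement(): every
// global access is a full 128-bit vector, so the contiguous extent must be
// divisible by 128 / sizeof_bits<Element>.
#include <cstdint>
#include <cstdio>

// Stand-in for cutlass::sizeof_bits<Element>::value (byte-sized types only).
template <typename Element>
constexpr int sizeof_bits_v = 8 * static_cast<int>(sizeof(Element));

template <typename Element>
bool contiguous_extent_ok(int extent) {
  return (extent % (128 / sizeof_bits_v<Element>)) == 0;
}

int main() {
  // float (32-bit): 4 elements per access; int8_t: 16 elements per access.
  std::printf("float,  extent 20 -> %s\n", contiguous_extent_ok<float>(20) ? "ok" : "rejected");   // ok (20 % 4 == 0)
  std::printf("int8_t, extent 20 -> %s\n", contiguous_extent_ok<int8_t>(20) ? "ok" : "rejected");  // rejected (20 % 16 != 0)
  std::printf("int8_t, extent 32 -> %s\n", contiguous_extent_ok<int8_t>(32) ? "ok" : "rejected");  // ok
  return 0;
}
```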
|
||||
|
||||
|
||||
@ -0,0 +1,330 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile)
|
||||
matrix from memory.
|
||||
|
||||
This iterator assumes TensorNDHWC layout of tensors in Global Memory.
|
||||
|
||||
The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
|
||||
backward data gradient (Dgrad), and backward weight gradient (Wgrad).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename ThreadMap_
|
||||
>
|
||||
class Conv3dWgradOutputGradientTileAccessIteratorOptimized {
|
||||
public:
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::TensorNDHWC;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 3;
|
||||
using ConvProblemSize = typename conv::Conv3dProblemSize;
|
||||
|
||||
static_assert(sizeof_bits<Element>::value >= 8,
|
||||
"WGRAD requires elements of size 8b or greater.");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
Layout layout;
|
||||
|
||||
int NZPQ;            // precomputed product of N*Z*P*Q for clearing predicates
|
||||
int ZPQ; // product of Z*P*Q
|
||||
unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ
|
||||
unsigned zpq_shr; // in device code.
|
||||
|
||||
int PQ; // product of P*Q
|
||||
unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ
|
||||
unsigned pq_shr; // in device code.
|
||||
|
||||
unsigned q_mul; // precomputed quantities for fast computation of div/% by Q
|
||||
unsigned q_shr; // in device code.
|
||||
|
||||
LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile
|
||||
LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile
|
||||
LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Layout const &layout
|
||||
): layout(layout) {
|
||||
|
||||
// Incremental offsets in units of bytes: (number of elements) * sizeof_bits<Element>::value / 8
|
||||
offset_next_strided = (ThreadMap::Delta::kStrided * layout.stride()[0])
|
||||
* sizeof_bits<Element>::value / 8;
|
||||
|
||||
offset_next_contiguous = (ThreadMap::Delta::kContiguous)
|
||||
* sizeof_bits<Element>::value / 8;
|
||||
|
||||
inc_next_nzpq = (Shape::kColumn * problem_size.split_k_slices * layout.stride()[0])
|
||||
* sizeof_bits<Element>::value / 8;
|
||||
|
||||
// Precompute several quantities for fast modulo arithmetic.
|
||||
NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q;
|
||||
ZPQ = problem_size.Z * problem_size.P * problem_size.Q;
|
||||
find_divisor(zpq_mul, zpq_shr, ZPQ);
|
||||
|
||||
PQ = problem_size.P * problem_size.Q;
|
||||
find_divisor(pq_mul, pq_shr, PQ);
|
||||
|
||||
find_divisor(q_mul, q_shr, problem_size.Q);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
Params const ¶ms_;
|
||||
Conv3dProblemSize const &problem_size_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
char const *pointer_;
|
||||
|
||||
uint32_t predicates_;
|
||||
int filter_k_;
|
||||
int offset_nzpq_;
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradOutputGradientTileAccessIteratorOptimized(
|
||||
Params const ¶ms,
|
||||
Conv3dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
predicates_(0),
|
||||
filter_k_(0),
|
||||
offset_nzpq_(0) {
|
||||
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.row() + thread_coord.contiguous();
|
||||
offset_nzpq_ = threadblock_offset.column() + thread_coord.strided();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
||||
|
||||
int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous;
|
||||
int offset_nzpq = offset_nzpq_ + s * ThreadMap::Delta::kStrided;
|
||||
|
||||
bool predicate = valid_(at_(offset_nzpq, filter_k));
|
||||
|
||||
uint32_t pred = (predicate ? 1u : 0);
|
||||
|
||||
int pred_idx = c + s * ThreadMap::Iterations::kContiguous;
|
||||
|
||||
predicates_ |= (pred << pred_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0)
|
||||
pointer_ += (
|
||||
offset_nzpq_ * params.layout.stride()[0] + filter_k_
|
||||
) * sizeof_bits<Element>::value / 8;
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile
|
||||
offset_nzpq_ += Shape::kColumn * problem_size_.split_k_slices;
|
||||
|
||||
// Clear predicates if needed
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
if (offset_nzpq_ + s * ThreadMap::Delta::kStrided >= params_.NZPQ) {
|
||||
uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous);
|
||||
predicates_ = (predicates_ & (~kClearMask));
|
||||
}
|
||||
}
|
||||
pointer_ += params_.inc_next_nzpq;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Returns the coordinate in the output gradient tensor Dy that is (offset_nzpq, k) pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at_(int offset_nzpq, int k) const {
|
||||
|
||||
// The subsequent fast_divmod() operations are equivalent to the following logical computation:
|
||||
//
|
||||
//
|
||||
// int nzpq = offset_nzpq_;
|
||||
// int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
// int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q);
|
||||
//
|
||||
// int z = residual / (problem_size_.P * problem_size_.Q);
|
||||
// residual = residual % (problem_size_.P * problem_size_.Q);
|
||||
//
|
||||
// int p = residual / problem_size_.Q;
|
||||
// int q = residual % problem_size_.Q;
|
||||
|
||||
int residual, n, z, p, q;
|
||||
fast_divmod(n, residual, offset_nzpq, params_.ZPQ, params_.zpq_mul, params_.zpq_shr);
|
||||
fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr);
|
||||
fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr);
|
||||
|
||||
return TensorCoord(n, z, p, q, k);
|
||||
}
|
||||
|
||||
/// Returns true if the coord is within the output gradient tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid_(TensorCoord coord) const {
|
||||
|
||||
return coord.n() < problem_size_.N &&
|
||||
coord.c() < problem_size_.K;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Returns true if the current coordinate is within the output gradient tensor Dy
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
|
||||
LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous;
|
||||
return (predicates_ & (1u << pred_idx));
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
|
||||
return reinterpret_cast<AccessType const *>(
|
||||
pointer_ +
|
||||
iteration_strided_ * params_.offset_next_strided +
|
||||
iteration_contiguous_ * params_.offset_next_contiguous
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
Conv3dWgradOutputGradientTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv3dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.C % (128/sizeof_bits<Element>::value)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
};
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
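The optimized Wgrad output-gradient iterator folds all of its bounds checks into a 32-bit predicate mask: one bit per (strided, contiguous) access, set in the constructor, cleared in `advance()` when a slice runs past `NZPQ`, and tested by `valid()`. The following standalone sketch, with illustrative constants and not part of the commit, mirrors that pack/clear/test pattern.

```cpp
// Illustrative sketch of the predicate-mask pattern used by
// Conv3dWgradOutputGradientTileAccessIteratorOptimized.
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int kContiguous = 4;   // stand-ins for ThreadMap::Iterations::kContiguous / kStrided
constexpr int kStrided    = 2;

// Same packing as the constructor above: bit index = c + s * kContiguous.
uint32_t set_bit(uint32_t predicates, int s, int c, bool pred) {
  int pred_idx = c + s * kContiguous;
  return predicates | ((pred ? 1u : 0u) << pred_idx);
}

// Clears every contiguous bit of one strided slice, as advance() does when that
// slice's GEMM-K offset moves past NZPQ.
uint32_t clear_strided_slice(uint32_t predicates, int s) {
  uint32_t kClearMask = ((1u << kContiguous) - 1) << (s * kContiguous);
  return predicates & ~kClearMask;
}

// Equivalent of valid(): test the bit for the current (strided, contiguous) access.
bool valid(uint32_t predicates, int s, int c) {
  return (predicates >> (c + s * kContiguous)) & 1u;
}

int main() {
  uint32_t predicates = 0;
  for (int s = 0; s < kStrided; ++s)
    for (int c = 0; c < kContiguous; ++c)
      predicates = set_bit(predicates, s, c, /*in_bounds=*/true);

  assert(valid(predicates, 1, 3));
  predicates = clear_strided_slice(predicates, 1);   // slice 1 ran past NZPQ
  assert(!valid(predicates, 1, 3) && valid(predicates, 0, 3));
  std::printf("mask = 0x%02x\n", predicates);        // 0x0f: only slice 0 remains valid
  return 0;
}
```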
|
||||
|
||||
|
||||
480
include/cutlass/conv/threadblock/implicit_gemm_multistage.h
Normal file
@ -0,0 +1,480 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class ImplicitGemmMultistage :
|
||||
public gemm::threadblock::MmaBase<Shape_, Policy_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::MmaBase<Shape_, Policy_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
|
||||
using ElementC = typename Policy::Operator::ElementC;
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA =
    (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
    (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
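The ceiling division above spreads one stage's cp.async copies evenly across the warp-level GEMM iterations. A standalone check of that arithmetic, using assumed values (16 copies per stage, 6 warp-level GEMM iterations) that are illustrative rather than taken from any concrete tile shape in this diff:

// Illustrative only: verifies the kAccessesPerGroup ceiling division with assumed values.
#include <cassert>

constexpr int AsyncCopyIterationsPerStageA = 16;  // assumed copies needed per stage
constexpr int kWarpGemmIterations = 6;            // assumed warp-level MMAs per stage
constexpr int kAccessesPerGroupA =
    (AsyncCopyIterationsPerStageA + kWarpGemmIterations - 1) / kWarpGemmIterations;

int main() {
  // ceil(16 / 6) == 3: at most 3 cp.async copies are issued per warp-level MMA iteration.
  assert(kAccessesPerGroupA == 3);
  return 0;
}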
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
ImplicitGemmMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(
|
||||
IteratorA &iterator_A, IteratorB &iterator_B,
|
||||
int group_start_A = 0, int group_start_B = 0) {
|
||||
|
||||
iterator_A.set_iteration_index(group_start_A);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B);
|
||||
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum,
|
||||
///< Imaginary strides used for planar-complex only - ignored here
|
||||
int64_t imag_stride_A = 0,
|
||||
int64_t imag_stride_B = 0) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.advance();
|
||||
iterator_B.advance();
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A[2];
|
||||
WarpLoadedFragmentB warp_loaded_frag_B[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B[2];
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance(iterator_A, iterator_B);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
|
||||
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
warp_loaded_frag_A[warp_mma_k % 2],
|
||||
warp_loaded_frag_B[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
group_start_iteration_A = 0;
|
||||
group_start_iteration_B = 0;
|
||||
} else {
|
||||
group_start_iteration_A =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A,
|
||||
group_start_iteration_B);
|
||||
|
||||
warp_mma(
|
||||
accum,
|
||||
warp_transformed_frag_A[warp_mma_k % 2],
|
||||
warp_transformed_frag_B[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
|
||||
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.advance();
|
||||
iterator_B.advance();
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK *
|
||||
Base::kWarpGemmIterations,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
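The kernel above overlaps global-to-shared copies with math by software pipelining. A minimal host-side sketch of that schedule, assuming kStages = 3 and six k-tiles (both assumed values); it models only the bookkeeping, not the cp.async traffic or the warp-level MMAs:

// Minimal model of the multistage schedule: the prologue commits kStages - 1 stages,
// then each mainloop iteration computes on one stage while the copy for a later stage
// is in flight. Tail iterations (negative k counts) drain stages whose copies were
// already issued; in the real kernel those extra copies are predicated off.
#include <cstdio>

int main() {
  const int kStages = 3;       // assumed
  int gemm_k_iterations = 6;   // assumed number of k-tiles

  for (int stage = 0; stage < kStages - 1; ++stage, --gemm_k_iterations) {
    std::printf("prologue: commit copy of stage %d\n", stage);
  }

  int write_stage = kStages - 1;
  int read_stage = 0;

  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations) {
    std::printf("mainloop: math on stage %d, copy into stage %d\n", read_stage, write_stage);
    write_stage = (write_stage + 1) % kStages;
    read_stage = (read_stage + 1) % kStages;
  }
  return 0;
}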
313  include/cutlass/conv/threadblock/implicit_gemm_pipelined.h  Normal file
@ -0,0 +1,313 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Transformation applied to A operand
|
||||
typename TransformA_ = NumericArrayConverter<
|
||||
typename SmemIteratorA_::Element,
|
||||
typename IteratorA_::Element,
|
||||
IteratorA_::Fragment::kElements>,
|
||||
///
/// Transformation applied to B operand
|
||||
typename TransformB_ = NumericArrayConverter<
|
||||
typename SmemIteratorB_::Element,
|
||||
typename IteratorB_::Element,
|
||||
IteratorB_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class ImplicitGemmPipelined : public gemm::threadblock::MmaBase<Shape_, Policy_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::MmaBase<Shape_, Policy_, 2>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// statically assert kStages for MmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
ImplicitGemmPipelined(
|
||||
typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC &accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
FragmentC const &src_accum, ///< source accumulator tile
|
||||
TransformA transform_A = TransformA(), ///< transformation applied to A fragment
|
||||
TransformB transform_B = TransformB()) { ///< transformation applied to B fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A(tb_frag_A));
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
warp_mma(accum, warp_frag_A[warp_mma_k % 2],
|
||||
warp_frag_B[warp_mma_k % 2], accum);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
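ImplicitGemmPipelined fixes kStages at 2, so the shared-memory stage bookkeeping reduces to toggling one index, which is what the smem_write_stage_idx updates in the mainloop express. A small standalone illustration (the iteration count is assumed):

// Illustrative only: with two stages, the buffer being written and the buffer being read
// always differ, and flipping one index each k iteration swaps their roles.
#include <cassert>

int main() {
  int smem_write_stage_idx = 1;   // the prologue filled stage 0, so the next write targets stage 1
  for (int k = 0; k < 8; ++k) {   // assumed number of k iterations
    int write_stage = smem_write_stage_idx;
    int read_stage = smem_write_stage_idx ^ 1;
    assert(write_stage != read_stage);
    smem_write_stage_idx ^= 1;    // swap buffers for the next iteration
  }
  return 0;
}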
|
||||
@ -38,6 +38,9 @@
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -156,13 +159,23 @@ namespace gemm {
|
||||
template <int M, int N, int K>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, GemmShape<M,N,K> const &gemm_shape) {
|
||||
out << "cutlass::GemmShape::(kM, kN, kK) {"
|
||||
out << "cutlass::gemm::GemmShape::(kM, kN, kK) {"
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kM <<","
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kN <<","
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kK << "}";
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Default printing to ostream for GemmCoord
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) {
|
||||
out << "cutlass::gemm::GemmCoord:: {"
|
||||
<< gemm_coord.m() <<","
|
||||
<< gemm_coord.n() <<","
|
||||
<< gemm_coord.k() << "}";
|
||||
return out;
|
||||
}
|
||||
|
||||
} //namespace gemm
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -185,5 +198,44 @@ std::ostream & operator<<(std::ostream &out, PitchLinearShape<Contiguous, Stride
|
||||
} //namespace layout
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// stream operators for cutlass::conv namespace //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
namespace conv {
|
||||
/// Default printing to ostream for Conv2dProblemSize
|
||||
inline
|
||||
std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) {
|
||||
out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl
|
||||
<< "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl
|
||||
<< "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl
|
||||
<< "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl
|
||||
<< "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl
|
||||
<< "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl
|
||||
<< "split_k_slices: (" << problem.split_k_slices << ")" << std::endl
|
||||
<< "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
/// Default printing to ostream for Conv3dProblemSize
|
||||
inline
|
||||
std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) {
|
||||
out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl
|
||||
<< "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl
|
||||
<< "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl
|
||||
<< "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl
|
||||
<< "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl
|
||||
<< "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl
|
||||
<< "split_k_slices: (" << problem.split_k_slices << ") " << std::endl
|
||||
<< "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace conv
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
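A hedged usage sketch for the stream operators added above. It assumes Conv2dProblemSize is default-constructible and that the header patched here (its path is not shown in this diff) is included; the members set below are exactly the ones the operator<< body reads.

// Hypothetical example values; prints the NHWC/KRSC/NPQK description defined above.
#include <iostream>
#include "cutlass/conv/conv2d_problem_size.h"
// ...plus the header patched above that declares operator<< (path not shown in this diff).

int main() {
  cutlass::conv::Conv2dProblemSize problem;   // assumed default constructor
  problem.N = 1;   problem.H = 56; problem.W = 56; problem.C = 64;
  problem.K = 128; problem.R = 3;  problem.S = 3;
  problem.P = 56;  problem.Q = 56;
  problem.pad_h = 1;      problem.pad_w = 1;
  problem.stride_h = 1;   problem.stride_w = 1;
  problem.dilation_h = 1; problem.dilation_w = 1;
  problem.split_k_slices = 1;
  problem.mode = cutlass::conv::Mode::kConvolution;

  std::cout << problem << std::endl;
  return 0;
}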
|
||||
|
||||
@ -145,7 +145,7 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
@ -133,7 +133,7 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
@ -319,7 +319,7 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
@ -354,7 +354,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
|
||||
}
|
||||
|
||||
// Convert to destination numeric type
|
||||
@ -385,7 +385,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
|
||||
}
|
||||
|
||||
// Convert to destination numeric type
|
||||
@ -495,7 +495,7 @@ class FastLinearCombinationClamp {
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
@ -134,7 +134,7 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/half.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
@ -77,7 +78,6 @@ public:
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -88,15 +88,14 @@ public:
|
||||
beta(ElementCompute(0)),
|
||||
threshold(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr),
|
||||
threshold_ptr(nullptr) { }
|
||||
beta_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) {
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
@ -104,8 +103,8 @@ public:
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr,
|
||||
ElementCompute const *threshold_ptr = nullptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) {
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
@ -128,7 +127,7 @@ public:
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold);
|
||||
threshold_ = params.threshold;
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
@ -139,10 +138,16 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
if (k_partition != k_partition_count - 1) {
|
||||
// set to NaN to make ReLU no-op for all except last k partitions
|
||||
int64_t allones = -1;
|
||||
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
|
||||
}
|
||||
}
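The branch above disables the clamp for every k partition except the last by storing an all-ones bit pattern into the threshold, which is a NaN for IEEE floats. A standalone check of that property, using std::fmax as a stand-in for the ReLu functor used above (an assumption about the functor's max-style NaN handling):

// Standalone check: all-ones bits reinterpreted as float is a NaN, and an IEEE-style max
// (std::fmax here, standing in for the ReLu functor above) ignores a NaN bound, so the
// threshold becomes a no-op for the non-final k partitions.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int main() {
  int64_t allones = -1;
  float threshold;
  std::memcpy(&threshold, &allones, sizeof(threshold));  // low 32 bits: 0xFFFFFFFF
  assert(std::isnan(threshold));

  float x = -3.5f;
  assert(std::fmax(threshold, x) == x);  // the "clamp" leaves x unchanged
  return 0;
}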
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
@ -205,7 +210,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conditional guards to enable partial specialization for packed integers
|
||||
@ -245,7 +249,6 @@ public:
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -256,15 +259,14 @@ public:
|
||||
beta(ElementCompute(0)),
|
||||
threshold(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr),
|
||||
threshold_ptr(nullptr) { }
|
||||
beta_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) {
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
@ -272,8 +274,8 @@ public:
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr,
|
||||
ElementCompute const *threshold_ptr = nullptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) {
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
@ -296,7 +298,7 @@ public:
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold);
|
||||
threshold_ = params.threshold;
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
@ -307,10 +309,16 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
if (k_partition != k_partition_count - 1) {
|
||||
// set to NaN to make ReLU no-op for all except last k partitions
|
||||
int64_t allones = -1;
|
||||
threshold_ = reinterpret_cast<ElementCompute const &>(allones);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
@ -331,26 +339,41 @@ public:
|
||||
|
||||
multiplies<ComputeFragment> mul_add_source;
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
ReLu<FragmentAccumulator> relu;
|
||||
ReLu<ComputeFragment> relu;
|
||||
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
|
||||
if (platform::is_same<ElementOutput, int32_t>::value ||
|
||||
platform::is_same<ElementOutput, uint32_t>::value ||
|
||||
platform::is_same<ElementOutput, int16_t>::value ||
|
||||
platform::is_same<ElementOutput, uint16_t>::value ||
|
||||
platform::is_same<ElementOutput, int8_t>::value ||
|
||||
platform::is_same<ElementOutput, uint8_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value) {
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
scaled_accumulator = relu(threshold_, scaled_accumulator);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
||||
destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
} else {
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
||||
destination_converter;
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
@ -367,25 +390,48 @@ public:
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_accumulator;
|
||||
ReLu<FragmentAccumulator> relu;
|
||||
ReLu<ComputeFragment> relu;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
scaled_accumulator = relu(threshold_, scaled_accumulator);
|
||||
if (platform::is_same<ElementOutput, int32_t>::value ||
|
||||
platform::is_same<ElementOutput, uint32_t>::value ||
|
||||
platform::is_same<ElementOutput, int16_t>::value ||
|
||||
platform::is_same<ElementOutput, uint16_t>::value ||
|
||||
platform::is_same<ElementOutput, int8_t>::value ||
|
||||
platform::is_same<ElementOutput, uint8_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint1b_t>::value) {
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = __float2int_rn(intermediate[i]);
|
||||
}
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round>
|
||||
destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
} else {
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
||||
destination_converter;
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -398,4 +444,3 @@ public:
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -133,7 +133,7 @@ public:
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
|
||||
@ -367,6 +367,52 @@ struct DefaultInterleavedEpilogueTensorOp {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines sensible defaults for epilogues for TensorOps which use
/// interleaved output layout. For this case, shared memory is not needed.
|
||||
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,
|
||||
typename OutputOp_, int ElementsPerAccess, int InterleavedK,
|
||||
bool IsBetaZero = false, bool isSplitK = false>
|
||||
struct DefaultInterleavedConvEpilogue {
|
||||
using Shape = Shape_;
|
||||
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
using OutputOp = OutputOp_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputOp::ElementOutput;
|
||||
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
||||
|
||||
//
|
||||
// Thread map
|
||||
//
|
||||
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedConvThreadMapTensorOp<
|
||||
Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput,
|
||||
kElementsPerAccess, InterleavedK>::Type;
|
||||
|
||||
using OutputTileIterator =
|
||||
cutlass::epilogue::threadblock::InterleavedConvPredicatedTileIterator<
|
||||
OutputTileThreadMap, ElementOutput, InterleavedK>;
|
||||
|
||||
using AccumulatorFragmentIterator =
|
||||
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
||||
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
||||
// can reuse the gemm version here to do element selection
|
||||
layout::ColumnMajorInterleaved<InterleavedK>>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue<
|
||||
Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator,
|
||||
AccumulatorFragmentIterator, OutputOp, InterleavedK, IsBetaZero>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
@ -144,6 +144,55 @@ struct DefaultInterleavedThreadMapTensorOp {
|
||||
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines the optimal thread map for TensorOp accumulator layouts
|
||||
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
|
||||
typename Element_, int ElementsPerAccess, int InterleavedK>
|
||||
struct DefaultInterleavedConvThreadMapTensorOp {
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using WarpShape = WarpShape_;
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
using Element = Element_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kInterleavedK = InterleavedK;
|
||||
|
||||
//
|
||||
// Definitions
|
||||
//
|
||||
|
||||
struct Detail {
|
||||
/// Tensor Operations fundamentally perform operations on 8 rows
|
||||
static int const kTensorOpRows = 8;
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kN % WarpShape::kN),
|
||||
"Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount =
|
||||
gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
|
||||
ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
};
|
||||
|
||||
//
|
||||
// ThreadMap
|
||||
//
|
||||
|
||||
/// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
|
||||
/// InterleavedOutputTileThreadMap
|
||||
using Type = InterleavedConvOutputTileThreadMap<
|
||||
MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
|
||||
MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
|
||||
WarpShape::kN / InterleavedK>,
|
||||
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
|
||||
@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
template<
|
||||
typename TensorLayout_, ///! The original output tensor layout
|
||||
typename OutputIteratorLayout_, ///! Layout used by epilogue output iterator
|
||||
typename TensorRef_, ///! Input tensor to epilogue output iterator
|
||||
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
||||
typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem
|
||||
>
|
||||
struct ConvOutputIteratorParameter {
|
||||
|
||||
using TensorLayout = TensorLayout_;
|
||||
using OutputIteratorLayout = OutputIteratorLayout_;
|
||||
using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord;
|
||||
using TensorRef = TensorRef_;
|
||||
static conv::Operator const kConvolutionalOperator = ConvOperator;
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
|
||||
/// Wgrad stride idx for implicit gemm algorithm
|
||||
// Conv2d row-major matrix (KxRSC)
|
||||
// Conv3d row-major matrix (KxTRSC)
|
||||
static int const kWgradStrideIdx =
|
||||
platform::is_same<TensorLayout, layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorStrideIdx =
|
||||
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradStrideIdx : 0);
|
||||
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static OutputIteratorLayout layout(const TensorRef & ref) {
|
||||
return ref.stride(kTensorStrideIdx);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static OutputTensorCoord extent(ConvProblemSize problem_size) {
|
||||
return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <
|
||||
int InterleavedK,
|
||||
typename TensorRef_,
|
||||
conv::Operator ConvOperator,
|
||||
typename ConvProblemSize_
|
||||
>
|
||||
struct ConvOutputIteratorParameter<
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
TensorRef_,
|
||||
ConvOperator,
|
||||
ConvProblemSize_>
|
||||
{
|
||||
|
||||
using TensorLayout = typename layout::TensorNCxHWx<InterleavedK>;
|
||||
using OutputIteratorLayout = typename layout::TensorNCxHWx<InterleavedK>;
|
||||
using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord;
|
||||
using TensorRef = TensorRef_;
|
||||
static conv::Operator const kConvolutionalOperator = ConvOperator;
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static OutputIteratorLayout layout(const TensorRef & ref) {
|
||||
return ref.stride();
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static OutputTensorCoord extent(ConvProblemSize problem_size) {
|
||||
return problem_size.output_extent();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
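The compile-time selection in ConvOutputIteratorParameter picks which stride of the output tensor the epilogue iterator uses: Wgrad writes a KxRSC (2-D) or KxTRSC (3-D) row-major matrix, while Fprop and Dgrad use the leading stride of the NHWC/NDHWC output. A plain restatement of that choice, with boolean inputs standing in for the is_same and operator checks:

// Illustrative only: mirrors kWgradStrideIdx / kTensorStrideIdx with runtime booleans.
#include <cassert>

constexpr int select_stride_idx(bool is_nhwc_layout, bool is_wgrad) {
  int wgrad_stride_idx = is_nhwc_layout ? 2 : 3;  // KxRSC vs. KxTRSC row-major output
  return is_wgrad ? wgrad_stride_idx : 0;         // Fprop/Dgrad use the leading stride
}

int main() {
  assert(select_stride_idx(true, true) == 2);    // Conv2d Wgrad
  assert(select_stride_idx(false, true) == 3);   // Conv3d Wgrad
  assert(select_stride_idx(true, false) == 0);   // Fprop / Dgrad
  return 0;
}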
|
||||
@ -488,6 +488,68 @@ struct InterleavedOutputTileThreadMap {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template metaprogram for partitioning a 4D interleaved layout across warps
|
||||
/// to achieve several performance objectives:
|
||||
///
|
||||
/// - coalesced memory accesses in units of 64 Byte lines
|
||||
/// - minimal address arithmetic
|
||||
/// - minimal predicate calculations
|
||||
///
|
||||
template <typename WarpCount_, typename Iterations_, int Threads,
|
||||
int ElementsPerAccess, int ElementSize>
|
||||
struct InterleavedConvOutputTileThreadMap {
|
||||
using WarpCount = WarpCount_;
|
||||
|
||||
static int const kWarpSize = 32;
|
||||
static int const kThreads = Threads;
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kElementSize = ElementSize;
|
||||
|
||||
//
|
||||
// Metaprogram computation
|
||||
//
|
||||
|
||||
struct Detail {};
|
||||
|
||||
//
|
||||
// Output
|
||||
//
|
||||
|
||||
using Iterations = Iterations_;
|
||||
|
||||
using Delta = MatrixShape<kWarpSize / 4, 4 * kElementsPerAccess>;
|
||||
|
||||
/// Initial offset function
|
||||
CUTLASS_HOST_DEVICE
|
||||
static MatrixCoord initial_offset(int thread_idx) {
|
||||
int warp_idx = thread_idx / kWarpSize;
|
||||
int lane_idx = thread_idx % kWarpSize;
|
||||
|
||||
// Compute warp location
|
||||
MatrixCoord warp_footprint{
|
||||
Delta::kRow * Iterations::kRow,
|
||||
Delta::kColumn * Iterations::kColumn,
|
||||
};
|
||||
|
||||
MatrixCoord warp_offset{warp_idx % WarpCount::kRow,
|
||||
warp_idx / WarpCount::kRow};
|
||||
|
||||
// Compute per-lane offset
|
||||
MatrixCoord thread_offset_in_warp{lane_idx / 4,
|
||||
(lane_idx % 4) * kElementsPerAccess};
|
||||
|
||||
MatrixCoord thread_offset_in_threadblock_tile =
|
||||
warp_footprint * warp_offset + thread_offset_in_warp;
|
||||
|
||||
return thread_offset_in_threadblock_tile;
|
||||
}
|
||||
};
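A worked example of the initial_offset arithmetic above, with assumed shapes (kElementsPerAccess = 8, Iterations = <1, 2>, a warp-count row dimension of 2) and the MatrixCoord product taken as element-wise, which is how warp_footprint * warp_offset combines here:

// Illustrative only: computes one thread's initial offset with assumed shapes.
#include <cassert>

int main() {
  const int kWarpSize = 32, kElementsPerAccess = 8;                       // epa assumed
  const int DeltaRow = kWarpSize / 4, DeltaCol = 4 * kElementsPerAccess;  // Delta = <8, 32>
  const int IterRow = 1, IterCol = 2, WarpCountRow = 2;                   // assumed shapes

  int thread_idx = 45;                                                    // warp 1, lane 13
  int warp_idx = thread_idx / kWarpSize;
  int lane_idx = thread_idx % kWarpSize;

  int footprint_row = DeltaRow * IterRow, footprint_col = DeltaCol * IterCol;          // {8, 64}
  int warp_row = warp_idx % WarpCountRow, warp_col = warp_idx / WarpCountRow;          // {1, 0}
  int in_warp_row = lane_idx / 4, in_warp_col = (lane_idx % 4) * kElementsPerAccess;   // {3, 8}

  // Element-wise product of the warp footprint and warp coordinate, plus the lane offset.
  int row = footprint_row * warp_row + in_warp_row;   // 8 * 1 + 3 = 11
  int col = footprint_col * warp_col + in_warp_col;   // 64 * 0 + 8 = 8
  assert(row == 11 && col == 8);
  return 0;
}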
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
|
||||
@ -43,6 +43,7 @@
|
||||
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -102,68 +103,20 @@ public:
|
||||
// Parameters struct
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
LongIndex stride; ///< stride in bytes between rows
|
||||
|
||||
LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows
|
||||
LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group
|
||||
LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster
|
||||
|
||||
LongIndex advance_row; ///< amount to add to move to the next 'row' position
|
||||
LongIndex advance_group; ///< amount to add to move to the next 'group' position
|
||||
LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position
|
||||
LongIndex advance_tile; ///< amount to add to move to the next 'tile'
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
/// Uses a non-template class
|
||||
struct Params : PredicatedTileIteratorParams {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Status initialize(Index stride_) {
|
||||
|
||||
stride = LongIndex(stride_);
|
||||
|
||||
increment_row = stride * ThreadMap::Delta::kRow;
|
||||
|
||||
increment_group = stride * ThreadMap::Delta::kGroup
|
||||
- stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
|
||||
|
||||
increment_cluster = stride * ThreadMap::Delta::kCluster
|
||||
- stride * ThreadMap::Delta::kGroup * (ThreadMap::Iterations::kGroup - 1)
|
||||
- stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
|
||||
|
||||
advance_row = stride * ThreadMap::Shape::kRow;
|
||||
|
||||
advance_group = stride * (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
|
||||
|
||||
advance_cluster =
|
||||
stride *
|
||||
ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;;
|
||||
|
||||
advance_tile =
|
||||
stride *
|
||||
ThreadMap::Shape::kGroup *
|
||||
ThreadMap::Shape::kRow *
|
||||
ThreadMap::Shape::kCluster *
|
||||
ThreadMap::Shape::kTile;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {
|
||||
initialize(0);
|
||||
}
|
||||
Params(Layout const &layout):
|
||||
PredicatedTileIteratorParams(
|
||||
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
|
||||
make_OutputTileThreadMapDesc<ThreadMap>()
|
||||
)
|
||||
{
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const &layout) {
|
||||
|
||||
initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);
|
||||
}
|
||||
};
|
||||
|
||||
@ -207,7 +160,7 @@ private:
|
||||
//
|
||||
|
||||
/// Parameters structure containing reference and precomputed state.
|
||||
Params params_;
|
||||
PredicatedTileIteratorParams params_;
|
||||
|
||||
/// Byte-level pointer
|
||||
uint8_t *byte_pointer_;
|
||||
@ -239,12 +192,13 @@ public:
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
PredicatedTileIterator(
|
||||
Params const & params,
|
||||
PredicatedTileIteratorParams const & params,
|
||||
Element *pointer,
|
||||
TensorCoord extent,
|
||||
int thread_idx,
|
||||
TensorCoord threadblock_offset = TensorCoord()
|
||||
): params_(params)
|
||||
):
|
||||
params_(params)
|
||||
{
|
||||
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
||||
@ -745,6 +699,309 @@ public:
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator used to load output tile from shared memory in epilogue.
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator | InterleavedMaskedTileIterator | ForwardTileIterator
|
||||
///
|
||||
template <
|
||||
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
|
||||
typename Element_, ///< Element data type
|
||||
int InterleavedN ///< Number of Interleaved N
|
||||
>
|
||||
class InterleavedConvPredicatedTileIterator {
|
||||
public:
|
||||
using ThreadMap = ThreadMap_;
|
||||
|
||||
using Element = Element_;
|
||||
|
||||
using Layout = layout::TensorNCxHWx<InterleavedN>;
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
using TensorCoord = Tensor4DCoord;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
static int const kThreads = ThreadMap::kThreads;
|
||||
static int const kIterations = ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Fragment object
|
||||
using Fragment = Array<Element, ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Memory access size
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
|
||||
//
|
||||
// Parameters struct
|
||||
//
|
||||
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
LongIndex stride_col; ///< stride in bytes between columns
|
||||
LongIndex stride_row; ///< stride in bytes between rows
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Status initialize(typename Layout::Stride stride_) {
|
||||
stride_col = stride_[1];
|
||||
stride_row = stride_[2];
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {
|
||||
initialize(cutlass::make_Coord(0, 0, 0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const &layout) {
|
||||
|
||||
initialize(layout.stride());
|
||||
}
|
||||
};
|
||||
|
||||
/// Mask object
|
||||
struct Mask {
|
||||
static int const kCount =
|
||||
(ThreadMap::Iterations::kRow < 8) ? 8 : ThreadMap::Iterations::kRow;
|
||||
|
||||
/// Predicate state
|
||||
bool predicates[kCount];
|
||||
|
||||
//
|
||||
// Mask
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Mask() {
|
||||
enable();
|
||||
}
|
||||
|
||||
///< Efficiently disables all accesses guarded by mask
|
||||
CUTLASS_HOST_DEVICE void clear() {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
predicates[i] = false;
|
||||
}
|
||||
}
|
||||
|
||||
///< Efficiently enables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void enable() {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
predicates[i] = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters structure containing reference and precomputed state.
|
||||
Params params_;
|
||||
|
||||
/// Byte-level pointer
|
||||
uint8_t *byte_pointer_;
|
||||
|
||||
/// Array of boolean values to contain steady-state predicates
|
||||
Mask mask_;
|
||||
|
||||
/// Extent of the matrix tile in columns
|
||||
Index extent_col_;
|
||||
|
||||
/// Extent of the matrix tile in rows
|
||||
Index extent_row_;
|
||||
|
||||
/// Extent of the matrix tile in pq
|
||||
Index extent_pq_;
|
||||
|
||||
/// A thread's starting row position (assuming steady-state predicates have
|
||||
/// been computed)
|
||||
Index thread_start_row_;
|
||||
|
||||
/// A thread's starting column position (assuming steady-state predicates have
|
||||
/// been computed)
|
||||
Index thread_start_col_;
|
||||
|
||||
/// Internal iteration counter
|
||||
LongIndex iteration_row_;
|
||||
LongIndex iteration_col_;
|
||||
|
||||
uint32_t pq_mul_;
|
||||
|
||||
uint32_t pq_shr_;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
InterleavedConvPredicatedTileIterator(
|
||||
Params const & params,
|
||||
Element *pointer,
|
||||
TensorCoord extent,
|
||||
int thread_idx,
|
||||
MatrixCoord threadblock_offset
|
||||
):
|
||||
params_(params) {
|
||||
MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
||||
|
||||
extent_col_ = extent.c();
|
||||
extent_pq_ = extent.h() * extent.w();
|
||||
extent_row_ = extent.n() * extent_pq_;
|
||||
|
||||
find_divisor(pq_mul_, pq_shr_, extent_pq_);
|
||||
|
||||
thread_start_row_ = thread_offset.row();
|
||||
thread_start_col_ = thread_offset.column();
|
||||
|
||||
// Initialize predicates
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int r = 0; r < ThreadMap::Iterations::kRow; ++r) {
|
||||
mask_.predicates[r] =
|
||||
((thread_offset.row() + ThreadMap::Delta::kRow * r) < extent_row_);
|
||||
}
|
||||
|
||||
// Initialize pointer
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
((thread_start_col_ / InterleavedN) * params_.stride_col +
|
||||
(thread_start_col_ % InterleavedN)) *
|
||||
sizeof_bits<Element>::value / 8;
|
||||
|
||||
// Initialize internal state counter
|
||||
iteration_row_ = iteration_col_ = 0;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) {
|
||||
|
||||
int col_offset = iteration_col_ * ThreadMap::Delta::kColumn;
|
||||
bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
|
||||
bool guard = col_guard && mask_.predicates[iteration_row_];
|
||||
|
||||
int n, pq_rem;
|
||||
|
||||
fast_divmod(n, pq_rem,
|
||||
thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow,
|
||||
extent_pq_, pq_mul_, pq_shr_);
|
||||
|
||||
uint8_t *byte_pointer =
|
||||
byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) *
|
||||
sizeof_bits<Element>::value / 8;
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
AccessType const *memory_pointer =
|
||||
reinterpret_cast<AccessType const *>(byte_pointer);
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
*frag_ptr,
|
||||
(void *)memory_pointer,
|
||||
guard);
|
||||
}
|
||||
|
||||
/// Stores a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &frag) {
|
||||
|
||||
int col_offset = iteration_col_ * ThreadMap::Delta::kColumn;
|
||||
bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
|
||||
bool guard = col_guard && mask_.predicates[iteration_row_];
|
||||
|
||||
int n, pq_rem;
|
||||
|
||||
fast_divmod(n, pq_rem,
|
||||
thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow,
|
||||
extent_pq_, pq_mul_, pq_shr_);
|
||||
|
||||
uint8_t *byte_pointer =
|
||||
byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) *
|
||||
sizeof_bits<Element>::value / 8;
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
|
||||
|
||||
if (guard) {
|
||||
*memory_pointer = *frag_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(int iteration) {
|
||||
iteration_row_ = iteration % ThreadMap::Iterations::kRow;
|
||||
iteration_col_ = iteration / ThreadMap::Iterations::kRow;
|
||||
}
|
||||
|
||||
/// Advances to the next position to load or store
|
||||
CUTLASS_HOST_DEVICE
|
||||
InterleavedConvPredicatedTileIterator &operator++() {
|
||||
|
||||
++iteration_row_;
|
||||
|
||||
if (iteration_row_ == ThreadMap::Iterations::kRow) {
|
||||
|
||||
iteration_row_ = 0;
|
||||
++iteration_col_;
|
||||
byte_pointer_ += params_.stride_col;
|
||||
|
||||
if (iteration_col_ == ThreadMap::Iterations::kColumn) {
|
||||
iteration_col_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< Efficiently disables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void clear_mask() {
|
||||
mask_.clear();
|
||||
}
|
||||
|
||||
///< Efficiently enables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void enable_mask() {
|
||||
mask_.enable();
|
||||
}
|
||||
|
||||
///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask &mask) {
mask = mask_;
}
|
||||
|
||||
///< Sets the mask
|
||||
CUTLASS_DEVICE void set_mask(Mask const &mask) {
|
||||
mask_ = mask;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
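
The interleaved iterator above addresses an NCxHWx tensor by splitting the logical column (channel) into a channel group and a remainder within the interleave, and by splitting the logical row into an image index n and a pixel index via fast_divmod. A small host-side sketch of that element-offset arithmetic follows; the helper and the stride values are illustrative stand-ins, not the CUTLASS layout API.

// Standalone sketch (assumptions, not the CUTLASS implementation): element offset
// implied by an interleaved NCxHWx tensor for logical coordinates (n, c, pq),
// mirroring the pointer arithmetic in the iterator above.
#include <cstdint>
#include <cstdio>

int64_t interleaved_offset(int n, int c, int pq,
                           int64_t stride_row,   // elements between images (n)
                           int64_t stride_col,   // elements between channel groups
                           int interleave) {     // InterleavedN
  int64_t channel_group = c / interleave;        // which channel group
  int64_t channel_rem   = c % interleave;        // position inside the group
  return channel_group * stride_col
       + channel_rem
       + n * stride_row
       + int64_t(pq) * interleave;               // which output pixel
}

int main() {
  // Hypothetical strides (in elements) as a TensorNCxHWx<32> layout might supply them.
  const int interleave = 32;
  const int64_t stride_row = 100352;
  const int64_t stride_col = 200704;
  printf("%lld\n", (long long)interleaved_offset(1, 40, 100, stride_row, stride_col, interleave));
  return 0;
}
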
|
||||
} // namespace threadblock
|
||||
|
||||
@ -0,0 +1,227 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct OutputTileShapeDesc {
|
||||
|
||||
int column;
|
||||
int row;
|
||||
int group;
|
||||
int cluster;
|
||||
int tile;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
OutputTileShapeDesc(): column(0), row(0), group(0), cluster(0), tile(0) { }
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
OutputTileShapeDesc(
|
||||
int column_,
|
||||
int row_,
|
||||
int group_,
|
||||
int cluster_,
|
||||
int tile_
|
||||
):
|
||||
column(column_),
|
||||
row(row_),
|
||||
group(group_),
|
||||
cluster(cluster_),
|
||||
tile(tile_) { }
|
||||
|
||||
/// Total number of points in the 5D space
|
||||
CUTLASS_HOST_DEVICE
|
||||
int count() const {
|
||||
return column * row * group * cluster * tile;
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper template to construct an OutputTileShapeDesc from a OutputTileShape template.
|
||||
template <typename Shape>
|
||||
CUTLASS_HOST_DEVICE
|
||||
OutputTileShapeDesc make_OutputTileShapeDesc() {
|
||||
return OutputTileShapeDesc(
|
||||
Shape::kColumn,
|
||||
Shape::kRow,
|
||||
Shape::kGroup,
|
||||
Shape::kCluster,
|
||||
Shape::kTile
|
||||
);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Thread map description
|
||||
struct OutputTileThreadMapDesc {
|
||||
|
||||
int threads;
|
||||
int elements_per_access;
|
||||
OutputTileShapeDesc shape;
|
||||
OutputTileShapeDesc iterations;
|
||||
OutputTileShapeDesc delta;
|
||||
OutputTileShapeDesc count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
OutputTileThreadMapDesc() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
OutputTileThreadMapDesc(
|
||||
int threads_,
|
||||
int elements_per_access_,
|
||||
OutputTileShapeDesc shape_,
|
||||
OutputTileShapeDesc iterations_,
|
||||
OutputTileShapeDesc delta_,
|
||||
OutputTileShapeDesc count_
|
||||
):
|
||||
threads(threads_),
|
||||
elements_per_access(elements_per_access_),
|
||||
shape(shape_),
|
||||
iterations(iterations_),
|
||||
delta(delta_),
|
||||
count(count_) { }
|
||||
};
|
||||
|
||||
/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template.
|
||||
template <typename ThreadMap>
|
||||
CUTLASS_HOST_DEVICE
|
||||
OutputTileThreadMapDesc make_OutputTileThreadMapDesc() {
|
||||
return OutputTileThreadMapDesc(
|
||||
ThreadMap::kThreads,
|
||||
ThreadMap::kElementsPerAccess,
|
||||
make_OutputTileShapeDesc<typename ThreadMap::Shape>(),
|
||||
make_OutputTileShapeDesc<typename ThreadMap::Iterations>(),
|
||||
make_OutputTileShapeDesc<typename ThreadMap::Delta>(),
|
||||
make_OutputTileShapeDesc<typename ThreadMap::Count>()
|
||||
);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Parameters struct
|
||||
//
|
||||
|
||||
struct PredicatedTileIteratorParams {
|
||||
|
||||
using Index = int32_t;
|
||||
using LongIndex = int64_t;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
LongIndex stride; ///< stride in bytes between rows
|
||||
|
||||
LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows
|
||||
LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group
|
||||
LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster
|
||||
|
||||
LongIndex advance_row; ///< amount to add to move to the next 'row' position
|
||||
LongIndex advance_group; ///< amount to add to move to the next 'group' position
|
||||
LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position
|
||||
LongIndex advance_tile; ///< amount to add to move to the next 'tile'
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Status initialize(Index stride_, OutputTileThreadMapDesc thread_map) {
|
||||
|
||||
stride = LongIndex(stride_);
|
||||
|
||||
increment_row = stride * thread_map.delta.row;
|
||||
|
||||
increment_group = stride * thread_map.delta.group
|
||||
- stride * thread_map.delta.row * (thread_map.iterations.row - 1);
|
||||
|
||||
increment_cluster = stride * thread_map.delta.cluster
|
||||
- stride * thread_map.delta.group * (thread_map.iterations.group - 1)
|
||||
- stride * thread_map.delta.row * (thread_map.iterations.row - 1);
|
||||
|
||||
advance_row = stride * thread_map.shape.row;
|
||||
|
||||
advance_group =
|
||||
stride *
|
||||
(thread_map.shape.group - 1) * thread_map.shape.row * thread_map.count.row;
|
||||
|
||||
advance_cluster =
|
||||
stride *
|
||||
thread_map.count.group *
|
||||
thread_map.shape.group *
|
||||
thread_map.count.row *
|
||||
thread_map.shape.row;
|
||||
|
||||
advance_tile =
|
||||
stride *
|
||||
thread_map.shape.group *
|
||||
thread_map.shape.row *
|
||||
thread_map.shape.cluster *
|
||||
thread_map.shape.tile;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedTileIteratorParams() {
|
||||
initialize(0, OutputTileThreadMapDesc());
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedTileIteratorParams(Index stride, OutputTileThreadMapDesc thread_map) {
|
||||
|
||||
initialize(stride, thread_map);
|
||||
}
|
||||
};
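
Each increment above advances to the next row, group, or cluster while subtracting off the row and group steps already taken inside the smaller unit, so the byte pointer never needs to be reset. The standalone sketch below evaluates the same formulas for hypothetical thread-map numbers (stride, deltas and iteration counts are invented for illustration, not taken from a real configuration).

// Standalone sketch (hypothetical numbers, not a CUTLASS unit test): evaluates the
// increment formulas from PredicatedTileIteratorParams::initialize().
#include <cstdint>
#include <cstdio>

int main() {
  int64_t stride = 1024;            // bytes between consecutive rows of the output tile
  int delta_row = 8,   iterations_row = 2;
  int delta_group = 32, iterations_group = 2;
  int delta_cluster = 128;

  int64_t increment_row = stride * delta_row;
  int64_t increment_group = stride * delta_group
                          - stride * delta_row * (iterations_row - 1);
  int64_t increment_cluster = stride * delta_cluster
                            - stride * delta_group * (iterations_group - 1)
                            - stride * delta_row * (iterations_row - 1);

  // Each increment moves to the next unit while undoing the steps already taken
  // inside the smaller unit.
  printf("row %lld  group %lld  cluster %lld\n",
         (long long)increment_row, (long long)increment_group, (long long)increment_cluster);
  return 0;
}
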
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -37,7 +37,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__clang__)
|
||||
#if !(defined(__clang__) && defined(__CUDA__))
|
||||
|
||||
#include "cutlass/wmma_array.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
@ -152,5 +152,7 @@ public:
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#else
|
||||
#error (defined(__clang__) && defined(__CUDA__))
|
||||
#endif // !defined(__clang__)
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__clang__)
|
||||
#if !(defined(__clang__) && defined(__CUDA__))
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/wmma_array.h"
|
||||
|
||||
@ -200,7 +200,7 @@ void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigne
|
||||
// Use IMUL.HI if div != 1, else simply copy the source.
|
||||
quo = (div != 1) ? __umulhi(src, mul) >> shr : src;
|
||||
#else
|
||||
quo = int((div != 1) ? int(src * mul) >> shr : src);
|
||||
quo = int((div != 1) ? int(((int64_t)src * mul) >> 32) >> shr : src);
|
||||
#endif
|
||||
|
||||
// The remainder.
|
||||
@ -215,7 +215,7 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul,
|
||||
// Use IMUL.HI if div != 1, else simply copy the source.
|
||||
quo = (div != 1) ? __umulhi(src, mul) >> shr : src;
|
||||
#else
|
||||
quo = int((div != 1) ? (src * mul) >> shr : src);
|
||||
quo = int((div != 1) ? ((src * mul) >> 32) >> shr : src);
|
||||
#endif
|
||||
// The remainder.
|
||||
rem = src - (quo * div);
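
The corrected host path above takes the high 32 bits of the 64-bit product before applying the shift, matching what __umulhi does on the device. As a rough, self-contained illustration of the magic-number division this relies on (the helpers below are a local re-implementation written for this note, not the cutlass::find_divisor / fast_divmod routines themselves):

// Standalone sketch of division by a precomputed (mul, shr) pair.
#include <cstdint>
#include <cstdio>

void find_divisor(uint32_t &mul, uint32_t &shr, uint32_t div) {
  if (div == 1) { mul = 0; shr = 0; return; }
  uint32_t p = 31;
  while ((1ull << (p - 31)) < div) ++p;          // p = 31 + ceil(log2(div))
  mul = uint32_t(((1ull << p) + div - 1) / div); // ceil(2^p / div)
  shr = p - 32;
}

void fast_divmod(int &quo, int &rem, int src, int div, uint32_t mul, uint32_t shr) {
  quo = (div != 1) ? int((int64_t(src) * mul) >> 32) >> shr : src; // high 32 bits, then shift
  rem = src - quo * div;
}

int main() {
  int extent_pq = 56 * 56;                 // e.g. P*Q of a convolution output
  uint32_t mul, shr;
  find_divisor(mul, shr, extent_pq);
  int n, pq;
  fast_divmod(n, pq, 3 * extent_pq + 123, extent_pq, mul, shr);
  printf("n=%d pq=%d\n", n, pq);           // expect n=3 pq=123
  return 0;
}

With divisor 3136 (a 56x56 output), index 3*3136 + 123 decomposes into n=3 and pq=123.
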
|
||||
@ -161,6 +161,42 @@ struct negate {
|
||||
}
|
||||
};
|
||||
|
||||
/// Greater equal
|
||||
template <typename T>
|
||||
struct greater_equal {
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(T const &lhs, T const &rhs) const {
|
||||
return (lhs >= rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/// Greater
|
||||
template <typename T>
|
||||
struct greater {
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(T const &lhs, T const &rhs) const {
|
||||
return (lhs > rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/// Less equal
|
||||
template <typename T>
|
||||
struct less_equal {
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(T const &lhs, T const &rhs) const {
|
||||
return (lhs <= rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/// Less
|
||||
template <typename T>
|
||||
struct less {
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(T const &lhs, T const &rhs) const {
|
||||
return (lhs < rhs);
|
||||
}
|
||||
};
|
||||
|
||||
/// Fused multiply-add
|
||||
template <typename A, typename B = A, typename C = A>
|
||||
struct multiply_add {
|
||||
@ -189,6 +225,40 @@ struct xor_add {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct conjugate {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &a) const {
|
||||
return a;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct conjugate<complex<T>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
complex<T> operator()(complex<T> const &a) const {
|
||||
return conj(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct conjugate<Array<T, N> > {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a) const {
|
||||
|
||||
conjugate<T> conj_op;
|
||||
|
||||
Array<T, N> ca;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
ca[i] = conj_op(a[i]);
|
||||
}
|
||||
return ca;
|
||||
}
|
||||
};
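
The conjugate functor is the identity for real types, complex conjugation for complex<T>, and element-wise application for Array<T, N>. The standalone snippet below reproduces that shape with std::complex and std::array so it compiles without the CUTLASS headers; it mirrors the behaviour of the functors above rather than reusing them.

// Standalone illustration using std:: types (stand-ins for cutlass::complex / cutlass::Array).
#include <array>
#include <complex>
#include <cstdio>

template <typename T> struct conjugate {
  T operator()(T const &a) const { return a; }               // real types: pass-through
};
template <typename T> struct conjugate<std::complex<T>> {
  std::complex<T> operator()(std::complex<T> const &a) const { return std::conj(a); }
};
template <typename T, size_t N> struct conjugate<std::array<T, N>> {
  std::array<T, N> operator()(std::array<T, N> const &a) const {
    conjugate<T> conj_op;
    std::array<T, N> ca;
    for (size_t i = 0; i < N; ++i) ca[i] = conj_op(a[i]);     // element-wise
    return ca;
  }
};

int main() {
  std::array<std::complex<float>, 2> x{{{1.f, 2.f}, {3.f, -4.f}}};
  auto cx = conjugate<decltype(x)>()(x);
  printf("(%g, %g) (%g, %g)\n", cx[0].real(), cx[0].imag(), cx[1].real(), cx[1].imag());
  return 0;
}
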
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Partial specialization for complex<T> to target four scalar fused multiply-adds.
|
||||
@ -1499,6 +1569,86 @@ struct multiply_add<Array<bfloat16_t, N>, Array<bfloat16_t, N>, Array<bfloat16_t
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator+(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
||||
plus<Array<T, N>> op;
|
||||
return op(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator-(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
||||
minus<Array<T, N>> op;
|
||||
return op(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator-(Array<T, N> const &lhs) {
|
||||
negate<Array<T, N>> op;
|
||||
return op(lhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator*(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
||||
multiplies<Array<T, N>> op;
|
||||
return op(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator*(T lhs, Array<T, N> const &rhs) {
|
||||
multiplies<Array<T, N>> op;
|
||||
return op(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator*(Array<T, N> const &lhs, T rhs) {
|
||||
multiplies<Array<T, N>> op;
|
||||
return op(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator/(Array<T, N> const &lhs, Array<T, N> const &rhs) {
|
||||
divides<Array<T, N>> op;
|
||||
return op(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
|
||||
multiply_add<Array<T, N>> op;
|
||||
return op(a, b, c);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> fma(T a, Array<T, N> const &b, Array<T, N> const &c) {
|
||||
multiply_add<Array<T, N>> op;
|
||||
return op(a, b, c);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> fma(Array<T, N> const &a, T b, Array<T, N> const &c) {
|
||||
multiply_add<Array<T, N>> op;
|
||||
return op(a, b, c);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, T c) {
|
||||
multiply_add<Array<T, N>> op;
|
||||
return op(a, b, c);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
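
With the operator overloads and fma helpers added above, fragment math can be written like scalar code. A minimal usage sketch, assuming the CUTLASS headers (cutlass/array.h, cutlass/functional.h) are on the include path:

// Minimal sketch: element-wise arithmetic on register fragments via the new overloads.
#include <cstdio>
#include "cutlass/array.h"
#include "cutlass/functional.h"

int main() {
  cutlass::Array<float, 4> a, b, c;
  for (int i = 0; i < 4; ++i) { a[i] = float(i); b[i] = 2.0f; c[i] = 1.0f; }

  cutlass::Array<float, 4> d = a * b + c;              // element-wise multiplies + plus
  cutlass::Array<float, 4> e = cutlass::fma(a, b, c);  // same result via multiply_add

  for (int i = 0; i < 4; ++i) printf("%f %f\n", float(d[i]), float(e[i]));
  return 0;
}
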
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -430,6 +430,25 @@ public:
|
||||
static_cast<int *>(workspace)
|
||||
};
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
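
The block above is the standard opt-in for kernels whose SharedStorage exceeds the 48 KB default: raise cudaFuncAttributeMaxDynamicSharedMemorySize (and optionally the carveout) once per kernel before launching. A self-contained CUDA sketch of the same pattern with a placeholder kernel; it assumes a GPU and build that support large dynamic shared memory (e.g. compiling for SM70 or newer).

// Standalone CUDA sketch of the >48 KB dynamic shared memory opt-in.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void uses_big_smem(int *out) {
  extern __shared__ int smem[];
  smem[threadIdx.x] = threadIdx.x;
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
  int smem_size = 64 * 1024;   // 64 KB of dynamic shared memory (> 48 KB default)

  if (smem_size >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        uses_big_smem, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (result != cudaSuccess) {
      printf("cudaFuncSetAttribute failed: %s\n", cudaGetErrorString(result));
      return 1;
    }
  }

  int *out = nullptr;
  cudaMalloc(&out, sizeof(int));
  uses_big_smem<<<1, 128, smem_size>>>(out);
  cudaError_t result = cudaGetLastError();
  printf("launch: %s\n", cudaGetErrorString(result));
  cudaFree(out);
  return result == cudaSuccess ? 0 : 1;
}
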
|
||||
@ -461,30 +480,11 @@ public:
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
cudaError_t result;
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
result = cudaGetLastError();
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
@ -118,8 +118,15 @@ public:
|
||||
using WarpShape = typename GemmKernel::WarpShape;
|
||||
using InstructionShape = typename GemmKernel::InstructionShape;
|
||||
|
||||
using OperatorClass = typename GemmKernel::OperatorClass;
|
||||
using ArchTag = typename GemmKernel::ArchTag;
|
||||
// warp-level, arch-level (instruction), math operator
|
||||
using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
// Operator class and arch tag are extracted bottom-up
// and set for the top-level device-level GEMM template
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
|
||||
// Type, layout, and complex transform deliberately exchanged with B
|
||||
using MapArguments = detail::MapArguments<
|
||||
|
||||
@ -312,6 +312,27 @@ public:
|
||||
static_cast<int *>(workspace)
|
||||
);
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
@ -335,38 +356,31 @@ public:
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
|
||||
|
||||
//
|
||||
// Configure grid and block dimensions
|
||||
//
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
cudaError_t result;
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block
|
||||
<< "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
result = cudaGetLastError();
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
|
||||
@ -49,6 +49,7 @@
|
||||
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h"
|
||||
@ -112,6 +113,101 @@ struct DefaultGemmComplex;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for SIMT (arch::Sm50)
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB,
|
||||
/// Multiply-add operator
|
||||
// (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
||||
typename Operator,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial
|
||||
>
|
||||
struct DefaultGemmComplex<
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassSimt,
|
||||
arch::Sm50, ThreadblockShape, WarpShape, InstructionShape,
|
||||
EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassSimt,
|
||||
Stages,
|
||||
Operator,
|
||||
false,
|
||||
cutlass::arch::CacheOperation::Global,
|
||||
cutlass::arch::CacheOperation::Global,
|
||||
TransformA,
|
||||
TransformB
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB, 0,
|
||||
typename MmaCore::IteratorThreadMapB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using Mma = cutlass::gemm::threadblock::MmaPipelined<
|
||||
typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
|
||||
IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
|
||||
layout::RowMajor, typename MmaCore::MmaPolicy>;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
|
||||
ThreadblockShape,
|
||||
typename Mma::Operator,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
@ -170,6 +266,70 @@ struct DefaultGemmComplex<
|
||||
using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB,
|
||||
/// Multiply-add operator
|
||||
// (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
||||
typename Operator,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial
|
||||
>
|
||||
struct DefaultGemmComplex<
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassSimt,
|
||||
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
|
||||
EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassSimt, arch::Sm80, ThreadblockShape,
|
||||
WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
|
||||
ThreadblockShape,
|
||||
typename Mma::Operator,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
|
||||
@ -138,8 +138,20 @@ struct Gemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D) {
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Mma::IteratorB::Layout,
|
||||
layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(ref_A, kAlignmentA)) {
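
The alignment logic above encodes one rule: interleaved operand layouts force the required alignment up to the interleave factor (32 or 64 elements), otherwise the iterator's native access width applies. A standalone sketch of that rule with local stand-in layout tags (the types below are illustrative, not the cutlass::layout classes):

// Standalone sketch of the interleaved-layout alignment rule.
#include <cstdio>
#include <type_traits>

template <int N> struct ColumnMajorInterleaved {};   // stand-ins for layout tags
struct ColumnMajor {};

template <typename Layout, int AccessElements>
constexpr int required_alignment() {
  return std::is_same<Layout, ColumnMajorInterleaved<32>>::value ? 32
       : std::is_same<Layout, ColumnMajorInterleaved<64>>::value ? 64
       : AccessElements;
}

int main() {
  printf("%d %d %d\n",
         required_alignment<ColumnMajorInterleaved<32>, 8>(),   // 32
         required_alignment<ColumnMajorInterleaved<64>, 8>(),   // 64
         required_alignment<ColumnMajor, 8>());                 // 8
  return 0;
}
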
@ -274,7 +286,7 @@ struct Gemm {
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k());
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
|
||||
@ -582,7 +582,7 @@ public:
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k());
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
|
||||
@ -302,8 +302,20 @@ public:
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
|
||||
layout::RowMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Mma::IteratorB::Layout,
|
||||
layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
|
||||
@ -468,7 +480,7 @@ public:
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k());
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
|
||||
|
||||
@ -319,7 +319,7 @@ struct SparseGemm {
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k());
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
|
||||
@ -93,6 +93,9 @@ struct Mma_HFMA2 <
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -179,6 +182,9 @@ struct Mma_HFMA2<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -270,6 +276,9 @@ struct Mma_HFMA2 <
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -356,6 +365,8 @@ struct Mma_HFMA2<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -443,6 +454,9 @@ struct Mma_HFMA2 <
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -533,6 +547,9 @@ struct Mma_HFMA2 <
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -623,6 +640,9 @@ struct Mma_HFMA2 <
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -714,6 +734,9 @@ struct Mma_HFMA2<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -800,6 +823,9 @@ struct Mma_HFMA2<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -879,6 +905,9 @@ struct Mma_HFMA2<
|
||||
/// C operand storage
|
||||
using FragmentC = Array<half_t, Shape::kMN>;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
@ -389,7 +389,7 @@ struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy = MmaPolicy<
|
||||
MmaWarpSimt,
|
||||
MatrixShape<kPaddingN, 0>, // skew for A matrix to avoid SMEM bank conflicts
|
||||
MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
|
||||
MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
|
||||
WarpCount::kK
|
||||
>;
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -1105,6 +1105,676 @@ struct DefaultMultistageMmaComplexCore<
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for complex double-precision
|
||||
///
|
||||
/// A: column-major
|
||||
/// B: row-major
|
||||
/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
typename RealA,
|
||||
typename RealB,
|
||||
typename RealC,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Complex transformation on operand A
|
||||
ComplexTransform TransformA,
|
||||
/// Complex transformation on operand B
|
||||
ComplexTransform TransformB,
|
||||
/// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultMultistageMmaComplexCore<
|
||||
Shape_, WarpShape_, GemmShape<1, 1, 1>,
|
||||
complex<RealA>, layout::ColumnMajor,
|
||||
complex<RealB>, layout::ColumnMajor,
|
||||
complex<RealC>, LayoutC_,
|
||||
arch::OpClassSimt,
|
||||
Stages,
|
||||
TransformA, TransformB,
|
||||
Operator_,
|
||||
CacheOpA, CacheOpB> {
|
||||
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = GemmShape<1, 1, 1>;
|
||||
using ElementA = complex<RealA>;
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using ElementB = complex<RealB>;
|
||||
using LayoutB = layout::ColumnMajor;
|
||||
using ElementC = complex<RealC>;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
static_assert(WarpCount::kCount > 1,
|
||||
"This specialization requires at least two warps.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of access
|
||||
static int const kAccessSizeInBits = sizeof_bits<ElementA>::value;
|
||||
|
||||
/// No vectorized accesses
|
||||
static int const kElementsPerAccess = 1;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::ColumnMajor;
|
||||
|
||||
using SmemLayoutB = layout::RowMajor;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM, Shape::kK>,
|
||||
kThreads,
|
||||
kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// Policy of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK, Shape::kN>,
|
||||
kThreads,
|
||||
kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Transpose the ThreadMap of iterator B
|
||||
using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapB>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
|
||||
SmemThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level op
|
||||
static const int WarpNumThreadsM = 4; // TODO need to extract these from template data
|
||||
static const int WarpNumThreadsN = 8;
|
||||
static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
|
||||
"WarpShape must be divisible by ThreadTile shape.");
|
||||
static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
|
||||
static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
|
||||
static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
|
||||
static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
|
||||
static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
|
||||
static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
|
||||
static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
|
||||
// these should have max of thread tile also
|
||||
using LaneMmaShape = cutlass::gemm::GemmShape<
|
||||
LaneM,
|
||||
LaneN,
|
||||
1>;
|
||||
using Policy = cutlass::gemm::warp::MmaSimtPolicy<
|
||||
cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
|
||||
cutlass::layout::RowMajorInterleaved<LaneLayout>, // LaneLayout
|
||||
LaneMmaShape
|
||||
>;
|
||||
|
||||
using MmaWarpSimt = cutlass::gemm::warp::MmaSimt<
|
||||
WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8
|
||||
ElementA, /// Data type of A elements
|
||||
SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout)
|
||||
ElementB, /// Data type of B elements
|
||||
SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout)
|
||||
ElementC, /// Element type of C matrix
|
||||
LayoutC, /// Layout of C matrix (concept: MatrixLayout)
|
||||
Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
>; /// Used for partial specialization
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy = MmaPolicy<
|
||||
MmaWarpSimt,
|
||||
MatrixShape<0, 0>,
|
||||
MatrixShape<0, Shape::kK / 32>,
|
||||
WarpCount::kK>;
|
||||
};
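
For a concrete feel of the policy constants chosen above, the standalone arithmetic below evaluates them for a hypothetical 32x64 warp tile of complex<float> (64 bits per element, two 32-bit reals). The warp shape is invented for illustration and is not a value mandated by the specialization.

// Standalone arithmetic check of the SIMT lane policy constants.
#include <algorithm>
#include <cstdio>

int main() {
  const int WarpShapeM = 32, WarpShapeN = 64;         // hypothetical warp tile
  const int WarpNumThreadsM = 4, WarpNumThreadsN = 8; // fixed 4x8 lane arrangement

  int ThreadTileM = WarpShapeM / WarpNumThreadsM;     // 8
  int ThreadTileN = WarpShapeN / WarpNumThreadsN;     // 8
  int LaneLayout  = (ThreadTileM > 4 && ThreadTileN > 4) ? 2 : 1;

  int element_bits = 64;                              // sizeof_bits<complex<float>>
  int numElements  = 128 / element_bits;              // 2 elements per 128-bit access
  int LaneM = std::min(numElements, ThreadTileM);     // 2
  int LaneN = std::min(numElements, ThreadTileN);     // 2

  printf("ThreadTile %dx%d  LaneLayout %d  LaneMmaShape %dx%dx1\n",
         ThreadTileM, ThreadTileN, LaneLayout, LaneM, LaneN);
  return 0;
}
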
|
||||
/// Partial specialization for complex double-precision
|
||||
///
|
||||
/// A: column-major
|
||||
/// B: row-major
|
||||
/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
typename RealA,
|
||||
typename RealB,
|
||||
typename RealC,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Complex transformation on operand A
|
||||
ComplexTransform TransformA,
|
||||
/// Complex transformation on operand B
|
||||
ComplexTransform TransformB,
|
||||
/// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultMultistageMmaComplexCore<
|
||||
Shape_, WarpShape_, GemmShape<1, 1, 1>,
|
||||
complex<RealA>, layout::ColumnMajor,
|
||||
complex<RealB>, layout::RowMajor,
|
||||
complex<RealC>, LayoutC_,
|
||||
arch::OpClassSimt,
|
||||
Stages,
|
||||
TransformA, TransformB,
|
||||
Operator_,
|
||||
CacheOpA, CacheOpB> {
|
||||
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = GemmShape<1, 1, 1>;
|
||||
using ElementA = complex<RealA>;
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using ElementB = complex<RealB>;
|
||||
using LayoutB = layout::RowMajor;
|
||||
using ElementC = complex<RealC>;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
static_assert(WarpCount::kCount > 1,
|
||||
"This specialization requires at least two warps.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of access
|
||||
static int const kAccessSizeInBits = sizeof_bits<ElementA>::value;
|
||||
|
||||
/// No vectorized accesses
|
||||
static int const kElementsPerAccess = 1;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::ColumnMajor;
|
||||
|
||||
using SmemLayoutB = layout::RowMajor;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kM, Shape::kK>,
|
||||
kThreads,
|
||||
kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// Policy of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kN, Shape::kK>,
|
||||
kThreads,
|
||||
kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
|
||||
IteratorThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level op
|
||||
static const int WarpNumThreadsM = 4; // TODO need to extract these from template data
|
||||
static const int WarpNumThreadsN = 8;
|
||||
static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
|
||||
"WarpShape must be divisible by ThreadTile shape.");
|
||||
static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
|
||||
static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
|
||||
static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
|
||||
static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
|
||||
static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
|
||||
static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
|
||||
static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
|
||||
// these should have max of thread tile also
|
||||
using LaneMmaShape = cutlass::gemm::GemmShape<
|
||||
LaneM,
|
||||
LaneN,
|
||||
1>;
|
||||
using Policy = cutlass::gemm::warp::MmaSimtPolicy<
|
||||
cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
|
||||
cutlass::layout::RowMajorInterleaved<LaneLayout>, // LaneLayout
|
||||
LaneMmaShape
|
||||
>;
|
||||
|
||||
using MmaWarpSimt = cutlass::gemm::warp::MmaSimt<
|
||||
WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8
|
||||
ElementA, /// Data type of A elements
|
||||
SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout)
|
||||
ElementB, /// Data type of B elements
|
||||
SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout)
|
||||
ElementC, /// Element type of C matrix
|
||||
LayoutC, /// Layout of C matrix (concept: MatrixLayout)
|
||||
Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
>; /// Used for partial specialization
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy = MmaPolicy<
|
||||
MmaWarpSimt,
|
||||
MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>, // or Shape::kK / 32
|
||||
WarpCount::kK>;
|
||||
};
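
// Editor's sketch (not part of the change set): a minimal instantiation of the
// SIMT complex specialization that ends above. The 64x64x16 threadblock tile,
// 32x32x16 warp tile and the two-stage pipeline are illustrative values chosen
// for this example, not configurations taken from the library.
using ExampleComplexSimtCore = DefaultMultistageMmaComplexCore<
    cutlass::gemm::GemmShape<64, 64, 16>,                      // threadblock tile
    cutlass::gemm::GemmShape<32, 32, 16>,                      // warp tile
    cutlass::gemm::GemmShape<1, 1, 1>,                         // SIMT "instruction" shape
    cutlass::complex<double>, cutlass::layout::ColumnMajor,    // A
    cutlass::complex<double>, cutlass::layout::RowMajor,       // B
    cutlass::complex<double>, cutlass::layout::RowMajor,       // C accumulator layout
    cutlass::arch::OpClassSimt,
    2,                                                         // stages
    cutlass::ComplexTransform::kNone,                          // transform on A
    cutlass::ComplexTransform::kNone,                          // transform on B
    cutlass::arch::OpMultiplyAddComplex,
    cutlass::arch::CacheOperation::Always,
    cutlass::arch::CacheOperation::Always>;

// Four warps of 32 threads cooperate on the 64x64 threadblock tile.
static_assert(ExampleComplexSimtCore::kThreads == 128, "expected a 2x2x1 warp arrangement");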

/// Partial specialization for complex double-precision
///
///   A: row-major
///   B: column-major
///   Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    typename RealA,
    typename RealB,
    typename RealC,
    /// Layout of accumulator
    typename LayoutC_,
    /// Number of stages
    int Stages,
    /// Complex transformation on operand A
    ComplexTransform TransformA,
    /// Complex transformation on operand B
    ComplexTransform TransformB,
    /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
    typename Operator_,
    /// Cache operation of operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Cache operation of operand B
    cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMultistageMmaComplexCore<
    Shape_, WarpShape_, GemmShape<1, 1, 1>,
    complex<RealA>, layout::RowMajor,
    complex<RealB>, layout::ColumnMajor,
    complex<RealC>, LayoutC_,
    arch::OpClassSimt,
    Stages,
    TransformA, TransformB,
    Operator_,
    CacheOpA, CacheOpB> {

  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = GemmShape<1, 1, 1>;
  using ElementA = complex<RealA>;
  using LayoutA = layout::RowMajor;
  using ElementB = complex<RealB>;
  using LayoutB = layout::ColumnMajor;
  using ElementC = complex<RealC>;
  using LayoutC = LayoutC_;
  static int const kStages = Stages;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;
  using Operator = Operator_;
  static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;

  /// Number of warps present
  using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
                              Shape::kN / WarpShape::kN,
                              Shape::kK / WarpShape::kK>;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");

  static_assert(WarpCount::kCount > 1,
    "This specialization requires at least two warps.");

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of access
  static int const kAccessSizeInBits = sizeof_bits<ElementA>::value;

  /// No vectorized accesses
  static int const kElementsPerAccess = 1;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::ColumnMajor;

  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kM>,
    kThreads,
    kElementsPerAccess
  >;

  /// Transpose the ThreadMap of iterator A
  using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapA>;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
      MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
      SmemThreadMapA>;

  /// Policy of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kN>,
    kThreads,
    kElementsPerAccess
  >;

  /// Transpose the ThreadMap of iterator B
  using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapB>;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
      MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
      SmemThreadMapB>;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level op
  static const int WarpNumThreadsM = 4; // TODO need to extract these from template data
  static const int WarpNumThreadsN = 8;
  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
      "WarpShape must be divisible by ThreadTile shape.");
  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
  // these should have max of thread tile also
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;
  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>,   // WarpShape
      cutlass::layout::RowMajorInterleaved<LaneLayout>,         // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::gemm::warp::MmaSimt<
      WarpShape,    /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8
      ElementA,     /// Data type of A elements
      SmemLayoutA,  /// Layout of A matrix (concept: MatrixLayout)
      ElementB,     /// Data type of B elements
      SmemLayoutB,  /// Layout of B matrix (concept: MatrixLayout)
      ElementC,     /// Element type of C matrix
      LayoutC,      /// Layout of C matrix (concept: MatrixLayout)
      Policy        /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
  >;              /// Used for partial specialization

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<
      MmaWarpSimt,
      MatrixShape<Shape::kK / 32, 0>,
      MatrixShape<0, Shape::kK / 32>,
      WarpCount::kK>;
};

/// Partial specialization for complex double-precision
///
///   A: row-major
///   B: row-major
///   Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    typename RealA,
    typename RealB,
    typename RealC,
    /// Layout of accumulator
    typename LayoutC_,
    /// Number of stages
    int Stages,
    /// Complex transformation on operand A
    ComplexTransform TransformA,
    /// Complex transformation on operand B
    ComplexTransform TransformB,
    /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
    typename Operator_,
    /// Cache operation of operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Cache operation of operand B
    cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMultistageMmaComplexCore<
    Shape_, WarpShape_, GemmShape<1, 1, 1>,
    complex<RealA>, layout::RowMajor,
    complex<RealB>, layout::RowMajor,
    complex<RealC>, LayoutC_,
    arch::OpClassSimt,
    Stages,
    TransformA, TransformB,
    Operator_,
    CacheOpA, CacheOpB> {

  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = GemmShape<1, 1, 1>;
  using ElementA = complex<RealA>;
  using LayoutA = layout::RowMajor;
  using ElementB = complex<RealB>;
  using LayoutB = layout::RowMajor;
  using ElementC = complex<RealC>;
  using LayoutC = LayoutC_;
  static int const kStages = Stages;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;
  using Operator = Operator_;
  static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;

  /// Number of warps present
  using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
                              Shape::kN / WarpShape::kN,
                              Shape::kK / WarpShape::kK>;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");

  static_assert(WarpCount::kCount > 1,
    "This specialization requires at least two warps.");

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of access
  static int const kAccessSizeInBits = sizeof_bits<ElementA>::value;

  /// No vectorized accesses
  static int const kElementsPerAccess = 1;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::ColumnMajor;

  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kM>,
    kThreads,
    kElementsPerAccess
  >;

  /// Transpose the ThreadMap of iterator A
  using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt<IteratorThreadMapA>;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
      MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
      SmemThreadMapA>;

  /// Policy of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kN, Shape::kK>,
    kThreads,
    kElementsPerAccess
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
      MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
      IteratorThreadMapB>;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level op
  static const int WarpNumThreadsM = 4; // TODO need to extract these from template data
  static const int WarpNumThreadsN = 8;
  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
      "WarpShape must be divisible by ThreadTile shape.");
  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
  // these should have max of thread tile also
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;
  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>,   // WarpShape
      cutlass::layout::RowMajorInterleaved<LaneLayout>,         // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::gemm::warp::MmaSimt<
      WarpShape,    /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8
      ElementA,     /// Data type of A elements
      SmemLayoutA,  /// Layout of A matrix (concept: MatrixLayout)
      ElementB,     /// Data type of B elements
      SmemLayoutB,  /// Layout of B matrix (concept: MatrixLayout)
      ElementC,     /// Element type of C matrix
      LayoutC,      /// Layout of C matrix (concept: MatrixLayout)
      Policy        /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
  >;              /// Used for partial specialization

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<
      MmaWarpSimt,
      MatrixShape<Shape::kK / 32, 0>,
      MatrixShape<0, 0>, // or Shape::kK / 32
      WarpCount::kK>;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm

@ -228,7 +228,7 @@ public:
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_A.get();

          cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
              dst_ptr + v, gmem_ptr, iterator_A.valid());

          ++iterator_A;
@ -258,7 +258,7 @@ public:
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B.get();

          cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
              dst_ptr + v, gmem_ptr, iterator_B.valid());

          ++iterator_B;
@ -514,6 +514,11 @@ public:

    }

    // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
    cutlass::arch::cp_async_fence();
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();

  }
};
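
// Editor's sketch (not part of the diff): the substance of the two hunks above.
// cp_async_zfill performs the same asynchronous global->shared copy as cp_async
// when the predicate is true, but when the predicate is false it zero-fills the
// shared-memory destination instead of skipping the transfer, so the residual
// k-iterations of the mainloop never consume stale shared-memory data. The
// kSrcBytes / kCacheOp parameters mirror the constants used by the class above.
template <int kSrcBytes, cutlass::arch::CacheOperation::Kind kCacheOp>
CUTLASS_DEVICE void guarded_copy(void *smem_dst, void const *gmem_src, bool guard) {
  cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOp>(smem_dst, gmem_src, guard);
}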
|
||||
|
||||
|
||||
@ -105,6 +105,14 @@ public:
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
using ArchTag = arch::Sm70;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// statically assert kStages for MmaSingleStage is 1 (single-stage mma pipeline)
|
||||
static_assert((Base::kStages==1), "MmaSingleStage requires kStages set to value 1");
|
||||
private:
|
||||
|
||||
@ -314,8 +314,17 @@ public:
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename Policy::Operator::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
@ -323,9 +332,6 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
@ -337,7 +343,7 @@ public:
|
||||
Operand::kA,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
|
||||
Policy::OpDelta::kRow,
|
||||
32,
|
||||
1
|
||||
@ -355,7 +361,7 @@ public:
|
||||
Operand::kB,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
|
||||
Policy::OpDelta::kColumn,
|
||||
32,
|
||||
1
|
||||
@ -368,14 +374,14 @@ public:
|
||||
using TransformedFragmentB = FragmentB;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % Policy::Operator::Shape::kM) &&
|
||||
!(Shape::kN % Policy::Operator::Shape::kN),
|
||||
!(Shape::kM % ArchMmaOperator::Shape::kM) &&
|
||||
!(Shape::kN % ArchMmaOperator::Shape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<
|
||||
Shape::kM / Policy::Operator::Shape::kM,
|
||||
Shape::kN / Policy::Operator::Shape::kN
|
||||
Shape::kM / ArchMmaOperator::Shape::kM,
|
||||
Shape::kN / ArchMmaOperator::Shape::kN
|
||||
>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
@ -383,7 +389,7 @@ public:
|
||||
MatrixShape<Shape::kM, Shape::kN>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
typename Policy::Operator::Shape,
|
||||
typename ArchMmaOperator::Shape,
|
||||
typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this
|
||||
@ -393,7 +399,7 @@ public:
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
static_assert(
|
||||
FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,
|
||||
FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements,
|
||||
"Unexpected planar complex fragment length.");
|
||||
|
||||
private:
|
||||
@ -403,7 +409,7 @@ private:
|
||||
//
|
||||
|
||||
/// Underlying real-valued matrix multiply operator (concept: arch::Mma)
|
||||
typename Policy::Operator mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
|
||||
@ -425,9 +431,9 @@ public:
|
||||
) const {
|
||||
|
||||
// Alias types for underlying real-valued matrix multiply operator
|
||||
using MmaOperandA = typename Policy::Operator::FragmentA;
|
||||
using MmaOperandB = typename Policy::Operator::FragmentB;
|
||||
using MmaOperandC = typename Policy::Operator::FragmentC;
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
static_assert(MmaOperandA::kElements == 1,
|
||||
"This implementation only supports math instructions in which exactly one element is needed for the A operand."
|
||||
@ -600,11 +606,17 @@ public:
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename Policy::Operator::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Underlying arch tag
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
@ -612,9 +624,6 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
@ -626,7 +635,7 @@ public:
|
||||
Operand::kA,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
|
||||
Policy::OpDelta::kRow,
|
||||
32,
|
||||
1
|
||||
@ -637,7 +646,7 @@ public:
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA =
|
||||
Array<typename Policy::Operator::ElementA, FragmentA::kElements * 2>;
|
||||
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements * 2>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<
|
||||
@ -645,7 +654,7 @@ public:
|
||||
Operand::kB,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
|
||||
Policy::OpDelta::kColumn,
|
||||
32,
|
||||
1
|
||||
@ -656,17 +665,17 @@ public:
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB =
|
||||
Array<typename Policy::Operator::ElementB, FragmentB::kElements * 2>;
|
||||
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements * 2>;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % Policy::Operator::Shape::kM) &&
|
||||
!(Shape::kN % Policy::Operator::Shape::kN),
|
||||
!(Shape::kM % ArchMmaOperator::Shape::kM) &&
|
||||
!(Shape::kN % ArchMmaOperator::Shape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
/// Number of complex product operations performed (one complex product needs four mma instructions)
|
||||
using MmaIterations = MatrixShape<
|
||||
Shape::kM / Policy::Operator::Shape::kM,
|
||||
Shape::kN / Policy::Operator::Shape::kN
|
||||
Shape::kM / ArchMmaOperator::Shape::kM,
|
||||
Shape::kN / ArchMmaOperator::Shape::kN
|
||||
>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
@ -674,7 +683,7 @@ public:
|
||||
MatrixShape<Shape::kM, Shape::kN>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
typename Policy::Operator::Shape,
|
||||
typename ArchMmaOperator::Shape,
|
||||
typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this
|
||||
@ -690,7 +699,7 @@ private:
|
||||
//
|
||||
|
||||
/// Underlying real-valued matrix multiply operator (concept: arch::Mma)
|
||||
typename Policy::Operator mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
|
||||
@ -712,11 +721,11 @@ public:
|
||||
) const {
|
||||
|
||||
// Alias types for underlying real-valued matrix multiply operator
|
||||
using InstMmaOperandA = typename Policy::Operator::FragmentA;
|
||||
using InstMmaOperandB = typename Policy::Operator::FragmentB;
|
||||
using MmaOperandC = typename Policy::Operator::FragmentC;
|
||||
using InstMmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using InstMmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
static_assert(platform::is_same<cutlass::gemm::GemmShape<16, 8, 8>, typename Policy::Operator::Shape>::value,
|
||||
static_assert(platform::is_same<cutlass::gemm::GemmShape<16, 8, 8>, typename ArchMmaOperator::Shape>::value,
|
||||
"This implementation only supports MMA.1688 math instructions.");
|
||||
|
||||
static_assert(InstMmaOperandA::kElements == 4,
|
||||
@ -794,8 +803,8 @@ public:
|
||||
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
|
||||
FragmentA const &A, FragmentB const &B) const {
|
||||
// Alias types for underlying real-valued matrix multiply operator
|
||||
using InstMmaOperandA = typename Policy::Operator::FragmentA;
|
||||
using InstMmaOperandB = typename Policy::Operator::FragmentB;
|
||||
using InstMmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using InstMmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
|
||||
//
|
||||
// Define conversions from source type to instruction operands' type
|
||||
|
||||
@ -147,11 +147,17 @@ public:
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename Policy::Operator::Shape;
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Underlying architecture tag
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Underlying arch tag
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
@ -159,8 +165,6 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
@ -173,7 +177,7 @@ public:
|
||||
Operand::kA,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
|
||||
Policy::OpDelta::kRow,
|
||||
32,
|
||||
1
|
||||
@ -191,7 +195,7 @@ public:
|
||||
Operand::kB,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
|
||||
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
|
||||
Policy::OpDelta::kColumn,
|
||||
32,
|
||||
1
|
||||
@ -204,14 +208,14 @@ public:
|
||||
using TransformedFragmentB = FragmentB;
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % Policy::Operator::Shape::kM) &&
|
||||
!(Shape::kN % Policy::Operator::Shape::kN),
|
||||
!(Shape::kM % ArchMmaOperator::Shape::kM) &&
|
||||
!(Shape::kN % ArchMmaOperator::Shape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<
|
||||
Shape::kM / Policy::Operator::Shape::kM,
|
||||
Shape::kN / Policy::Operator::Shape::kN
|
||||
Shape::kM / ArchMmaOperator::Shape::kM,
|
||||
Shape::kN / ArchMmaOperator::Shape::kN
|
||||
>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
@ -219,7 +223,7 @@ public:
|
||||
MatrixShape<Shape::kM, Shape::kN>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
typename Policy::Operator::Shape,
|
||||
typename ArchMmaOperator::Shape,
|
||||
typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this
|
||||
@ -229,7 +233,7 @@ public:
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
static_assert(
|
||||
FragmentC::kElements == 3 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,
|
||||
FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements,
|
||||
"Unexpected gaussian complex fragment length.");
|
||||
|
||||
private:
|
||||
@ -239,7 +243,7 @@ private:
|
||||
//
|
||||
|
||||
/// Underlying real-valued matrix multiply operator (concept: arch::Mma)
|
||||
typename Policy::Operator mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
|
||||
@ -261,9 +265,9 @@ public:
|
||||
) const {
|
||||
|
||||
// Alias types for underlying real-valued matrix multiply operator
|
||||
using MmaOperandA = typename Policy::Operator::FragmentA;
|
||||
using MmaOperandB = typename Policy::Operator::FragmentB;
|
||||
using MmaOperandC = typename Policy::Operator::FragmentC;
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
static_assert(MmaOperandA::kElements == 1,
|
||||
"This implementation only supports math instructions in which exactly one element is needed for the A operand."
|
||||
@ -346,8 +350,6 @@ public:
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// TODO - partial specializations of real*complex and complex*real
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
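
// Editor's sketch: the hunks above expose the arch-level instruction through the
// public ArchMmaOperator alias instead of the nested Policy::Operator name. The
// helper below is hypothetical (any warp-level Mma type touched by this change can
// be substituted for WarpMma); it only illustrates the kind of query the alias enables.
template <typename WarpMma>
struct InstructionShapeOf {
  using Shape = typename WarpMma::ArchMmaOperator::Shape;
  static_assert(Shape::kM > 0 && Shape::kN > 0 && Shape::kK > 0,
                "sanity check on the underlying instruction shape");
};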
|
||||
|
||||
@ -68,6 +68,10 @@ template <
|
||||
typename Policy_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK = 1,
|
||||
/// Complex transformation on operand A
|
||||
ComplexTransform TransformA = ComplexTransform::kNone,
|
||||
/// Complex transformation on operand B
|
||||
ComplexTransform TransformB = ComplexTransform::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
@ -104,10 +108,10 @@ public:
|
||||
using ArchTag = arch::Sm50;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
|
||||
/// Layout of threads
|
||||
using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value,
|
||||
@ -215,12 +219,22 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
FragmentC &d,
|
||||
FragmentA const &a,
|
||||
FragmentB const &b,
|
||||
FragmentA a,
|
||||
FragmentB b,
|
||||
FragmentC const &c, int group_idx = 0) const {
|
||||
|
||||
ThreadMma mma;
|
||||
|
||||
if (kTransformA == ComplexTransform::kConjugate) {
|
||||
conjugate<FragmentA> conj_a;
|
||||
a = conj_a(a);
|
||||
}
|
||||
|
||||
if (kTransformB == ComplexTransform::kConjugate) {
|
||||
conjugate<FragmentB> conj_b;
|
||||
b = conj_b(b);
|
||||
}
|
||||
|
||||
mma(d, a, b, c);
|
||||
}
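
// Editor's sketch (not part of the diff): the conjugation the operator() above now
// applies when kTransformA or kTransformB is ComplexTransform::kConjugate, shown on a
// concrete fragment. Assumes cutlass/array.h and cutlass/complex.h are included and
// that conjugate<> is specialized for Array fragments, as the code above relies on.
inline void conjugate_fragment_example() {
  cutlass::Array<cutlass::complex<float>, 4> frag;
  frag.fill(cutlass::complex<float>(1.0f, -2.0f));
  cutlass::conjugate<cutlass::Array<cutlass::complex<float>, 4>> conj_op;
  frag = conj_op(frag);  // every element becomes (1.0f, +2.0f)
}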
|
||||
|
||||
|
||||
@ -111,17 +111,28 @@ public:
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Equivalent base dense mma
|
||||
using Base = MmaTensorOp<Shape, ElementA, LayoutA, ElementB, LayoutB,
|
||||
ElementC, LayoutC, Policy, PartitionsK_,
|
||||
AccumulatorsInRowMajor, Enable>;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Base::ArchMmaOperator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
using ArchTag = typename Base::ArchTag;
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
using OperatorClass = typename Base::OperatorClass;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename Base::InstructionShape;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformA = Base::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = Base::kTransformB;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
@ -171,25 +182,19 @@ public:
|
||||
Array<typename Policy::Operator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
|
||||
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
|
||||
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
using IteratorB = typename Base::IteratorB;
|
||||
|
||||
/// Storage for B tile
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
using FragmentB = typename Base::FragmentB;
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB =
|
||||
Array<typename Policy::Operator::ElementB, FragmentB::kElements>;
|
||||
using TransformedFragmentB = typename Base::TransformedFragmentB;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<
|
||||
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename Policy::Operator::Shape, typename Policy::OpDelta>;
|
||||
using IteratorC = typename Base::IteratorC;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
using FragmentC = typename Base::FragmentC;
|
||||
|
||||
/// Iterates over the E operand in memory
|
||||
using IteratorE = SparseMmaTensorOpMetaTileIterator<
|
||||
@ -204,23 +209,13 @@ public:
|
||||
/// Storage for E tile
|
||||
using FragmentE = typename IteratorE::Fragment;
|
||||
|
||||
private:
|
||||
|
||||
static_assert(
|
||||
!(Shape::kM % Policy::Operator::Shape::kM) &&
|
||||
!(Shape::kN % Policy::Operator::Shape::kN),
|
||||
"Shape of warp-level Mma must be divisible by operator shape.");
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<
|
||||
Shape::kM / Policy::Operator::Shape::kM,
|
||||
Shape::kN / Policy::Operator::Shape::kN
|
||||
>;
|
||||
using MmaIterations = typename Base::MmaIterations;
|
||||
|
||||
public:
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
typename Policy::Operator mma;
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
|
||||
@ -299,21 +294,21 @@ public:
|
||||
// Define conversions from source type to instruction type
|
||||
//
|
||||
FloatRoundStyle const kRoundA =
|
||||
PreferredRoundingMode<typename Policy::Operator::ElementA,
|
||||
PreferredRoundingMode<typename ArchMmaOperator::ElementA,
|
||||
ElementA>::kRound;
|
||||
FloatRoundStyle const kRoundB =
|
||||
PreferredRoundingMode<typename Policy::Operator::ElementB,
|
||||
PreferredRoundingMode<typename ArchMmaOperator::ElementB,
|
||||
ElementB>::kRound;
|
||||
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
|
||||
detail::ConvertAndPack<typename ArchMmaOperator::ElementA, ElementA,
|
||||
FragmentA::kElements / 2, kRoundA>
|
||||
convert_A;
|
||||
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
|
||||
NumericArrayConverter<typename ArchMmaOperator::ElementB, ElementB,
|
||||
FragmentB::kElements, kRoundB>
|
||||
convert_B;
|
||||
Array<ElementA, FragmentA::kElements / 2> const *ptr_A =
|
||||
reinterpret_cast<Array<ElementA, FragmentA::kElements / 2> const *>(&A);
|
||||
Array<typename Policy::Operator::ElementA, FragmentA::kElements / 2> *
|
||||
ptr_dst_A = reinterpret_cast<Array<typename Policy::Operator::ElementA,
|
||||
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements / 2> *
|
||||
ptr_dst_A = reinterpret_cast<Array<typename ArchMmaOperator::ElementA,
|
||||
FragmentA::kElements / 2> *>(&dst_A);
|
||||
|
||||
dst_B = convert_B(B);
|
||||
|
||||
@ -244,8 +244,6 @@ public:
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
private:
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<
|
||||
(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
||||
|
||||
@ -1518,6 +1518,7 @@ class MmaTensorOpMultiplicandTileIterator<
|
||||
} else if (Layout::kFactor == 2) {
|
||||
// Super Matrix multiply kBlock = 32
|
||||
if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) {
|
||||
// Matrix multiply 1688 A/B
|
||||
// (Q stands for 1 8x128bit block).
|
||||
// Q0
|
||||
// Q1
|
||||
@ -3232,6 +3233,426 @@ public:

////////////////////////////////////////////////////////////////////////////////

/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store
/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major
/// accumulator layout.
///
/// Satisfies:
///   ReadableRandomAccessContiguousTileIteratorConcept |
///   WriteableRandomAccessContiguousTileIteratorConcept
///

template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Element type
    typename Element_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions, concept: MatrixShape)
    typename OpDelta_,
    /// Interleaved N
    int InterleavedN>
class MmaTensorOpAccumulatorTileIterator<
    Shape_, Element_, cutlass::layout::TensorNCxHWx<InterleavedN>,
    InstructionShape_, OpDelta_> {
 public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kC;

  /// Element type
  using Element = int8_t;

  /// Layout of source tile
  using Layout = cutlass::layout::TensorNCxHWx<InterleavedN>;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;

  /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
  using OpDelta = OpDelta_;

  /// Number of participating threads
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Internal structure of iterator - made public to enable introspection
  struct Policy {
    static_assert(
        !(Shape::kRow % InstructionShape::kM) &&
            !(Shape::kColumn % InstructionShape::kN),
        "Shape of warp-level Mma must be divisible by operator shape.");

    /// Number of elements in strided dimension that each STG writes
    static int const kStridedPerSTG = 8;

    /// Factor to calculate reorder index to pack accumulator.
    static int const kPackedFactor = Shape::kColumn / 32;

    /// Number of mma operations performed
    using MmaIterations = MatrixShape<Shape::kRow / kStridedPerSTG,
                                      Shape::kColumn / InterleavedN>;
  };

 private:

  static int const kElementsPerAccess = InterleavedN / 4;

 public:

  //
  // Derived quantities
  //

  struct alignas((kElementsPerAccess * sizeof_bits<Element>::value / 8)) AccessType {
    Array<Element, kElementsPerAccess> storage;
  };

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<int32_t, Shape::kCount / kThreads>;

 private:

  /// Reference to output tensor
  TensorRef ref_;

  /// Row offset index globally
  LongIndex global_offset_row_;

  /// Column offset index globally
  LongIndex global_offset_col_;

  /// Output tensor size
  TensorCoord extent_;

  /// Alpha
  float alpha_;

  /// Beta
  float beta_;

 public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaTensorOpAccumulatorTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaTensorOpAccumulatorTileIterator(
    TensorRef const &ref,
    int const lane_id,
    TensorCoord extent,
    float alpha = 1.0f,
    float beta = 0.0f
  ):
    ref_(ref),
    extent_(extent),
    alpha_(alpha),
    beta_(beta) {

    int quad = (lane_id >> 2);
    int lane_in_quad = (lane_id & 3);

    global_offset_row_ = quad;

    global_offset_col_ = lane_in_quad * kElementsPerAccess;
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaTensorOpAccumulatorTileIterator &add_tile_offset(MatrixCoord const &tile_offset) {

    global_offset_row_ += tile_offset.row() * Shape::kRow;

    global_offset_col_ += tile_offset.column() * Shape::kColumn;

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpAccumulatorTileIterator & operator++() {
    // deliberate no-op
    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpAccumulatorTileIterator & operator--() {
    // deliberate no-op
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_pointer_offset(
    Fragment &frag,                             ///< fragment to load from the tensor
    Index pointer_offset) const {               ///< loads a tile with a linear offset

    TensorRef offset_ref(ref_);
    offset_ref.add_pointer_offset(pointer_offset);

    AccessType* frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
        int accum_m = mma_m * InstructionShape::kM;
        int accum_n = mma_n * InstructionShape::kN;

        int idx = mma_m + mma_n * Policy::MmaIterations::kRow;

        AccessType* access_ptr = reinterpret_cast<AccessType *>(offset_ref.data() +
                                                                accum_m * offset_ref.stride(0) + accum_n);

        frag_ptr[idx] = access_ptr[0];
      }
    }
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_byte_offset(
    Fragment &frag,                             ///< fragment to load from the tensor
    Index byte_offset) const {                  ///< loads a tile with a linear offset

    load_with_pointer_offset(frag, byte_offset / sizeof(Element));
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    Fragment &frag,                             ///< fragment to load from the tensor
    TensorCoord const &tile_offset) const {     ///< loads a tile with a logical offset in units of whole tiles

    load(frag, tile_offset, 0);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    Fragment &frag,                             ///< fragment to load from the tensor
    TensorCoord const &tile_offset,             ///< loads a tile with a logical offset in units of whole tiles
    Index pointer_offset) const {               ///< loads a tile with a logical offset AND a pointer offset

    load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset);
  }

  /// Stores a fragment to memory
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory with additional pointer offset
  CUTLASS_DEVICE
  void store_with_pointer_offset(
    Fragment const &frag,                       ///< fragment to store from the tensor
    Index pointer_offset) const {               ///< store a tile with a linear offset

    TensorRef offset_ref(ref_);
    offset_ref.add_pointer_offset(pointer_offset);

    Array<float, Shape::kCount / kThreads> output_frag_f;
    Array<Element, Shape::kCount / kThreads> output_frag;

    LongIndex pq = extent_.h() * extent_.w();

    LongIndex extent_row = extent_.n() * pq;
    LongIndex extent_col = extent_.c();

    LongIndex k_major = (global_offset_col_ / InterleavedN) * pq;
    Index k_minor = global_offset_col_ % InterleavedN;
    LongIndex k_offset = k_major * InterleavedN + k_minor;
    LongIndex k_offset_delta = pq * InterleavedN;

    LongIndex stride_n = pq * extent_.c();

    Index n;
    LongIndex pq_rem;

    // Precompute magic numbers so the P*Q division/modulo in the store loop
    // can be done with fast_divmod instead of integer division per element.
    unsigned int pq_mul, pq_shr;
    find_divisor(pq_mul, pq_shr, pq);

    if(beta_ == 0.0f) {
      CUTLASS_PRAGMA_UNROLL
      for(int i = 0; i < frag.size(); ++i) {
        output_frag_f[i] = frag[i];
      }

      if(InstructionShape::kM == Policy::kStridedPerSTG) {
        CUTLASS_PRAGMA_UNROLL
        for(int i = 0; i < frag.size(); ++i) {
          output_frag[i] = (Element)(output_frag_f[i] * alpha_);
        }
      } else {
        CUTLASS_PRAGMA_UNROLL
        for(int i = 0; i < frag.size(); ++i) {
          int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor)
                      + (i % (8 * Policy::kPackedFactor)) / 2 * 4
                      + (i % (8 * Policy::kPackedFactor)) % 2
                      + (i / (8 * Policy::kPackedFactor)) % 2 * 2;
          output_frag[i] = (Element)(output_frag_f[map_i] * alpha_);
        }
      }

      AccessType const *frag_ptr = reinterpret_cast<AccessType const*>(&output_frag);

      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
        int accum_m = mma_m * Policy::kStridedPerSTG;

        fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr);
        LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN;

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {

          int accum_n = mma_n * InterleavedN;

          int idx = mma_n + mma_m * Policy::MmaIterations::kColumn;

          if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) {
            AccessType* access_ptr = reinterpret_cast<AccessType *>(offset_ref.data() +
                                                                    offset_m + mma_n * k_offset_delta);

            access_ptr[0] = frag_ptr[idx];
          }
        }
      }
    } else {
      if(InstructionShape::kM == Policy::kStridedPerSTG) {
        CUTLASS_PRAGMA_UNROLL
        for(int i = 0; i < frag.size(); ++i) {
          output_frag_f[i] = frag[i];
        }
      } else {
        CUTLASS_PRAGMA_UNROLL
        for(int i = 0; i < frag.size(); ++i) {
          int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor)
                      + (i % (8 * Policy::kPackedFactor)) / 2 * 4
                      + (i % (8 * Policy::kPackedFactor)) % 2
                      + (i / (8 * Policy::kPackedFactor)) % 2 * 2;
          output_frag_f[i] = frag[map_i];
        }
      }

      AccessType const *frag_ptr = reinterpret_cast<AccessType const*>(&output_frag);

      Array<Element, kElementsPerAccess> ref_frag;
      AccessType *ref_frag_ptr = reinterpret_cast<AccessType *>(&ref_frag);

      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
        int accum_m = mma_m * Policy::kStridedPerSTG;

        fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr);
        LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN;

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {

          int accum_n = mma_n * InterleavedN;

          int idx = mma_n + mma_m * Policy::MmaIterations::kColumn;

          if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) {
            AccessType* access_ptr = reinterpret_cast<AccessType *>(offset_ref.data() +
                                                                    offset_m + mma_n * k_offset_delta);

            ref_frag_ptr[0] = access_ptr[0];

            CUTLASS_PRAGMA_UNROLL
            for(int i = 0; i < kElementsPerAccess; ++i) {
              output_frag[idx * kElementsPerAccess + i] = Element(alpha_ * output_frag_f[idx * kElementsPerAccess + i]
                                                                  + beta_ * ref_frag[i]);
            }

            access_ptr[0] = frag_ptr[idx];
          }
        }
      }
    }
  }

  /// Stores a fragment to memory with additional pointer offset
  CUTLASS_DEVICE
  void store_with_byte_offset(
    Fragment const &frag,                       ///< fragment to store from the tensor
    Index byte_offset) const {                  ///< store a tile with a linear offset

    store_with_pointer_offset(frag, byte_offset / sizeof(Element));
  }

  /// Stores a fragment to memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void store(
    Fragment &frag,                             ///< fragment to store to the tensor
    TensorCoord const &tile_offset) const {     ///< stores a tile with a logical offset in units of whole tiles

    store(frag, tile_offset, 0);
  }

  /// Stores a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void store(
    /// fragment to store to the tensor
    Fragment const &frag,
    /// stores a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// stores a tile with a logical offset AND a pointer offset
    Index pointer_offset) const {
    store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset);
  }
};
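
// Editor's sketch (hypothetical kernel fragment, not part of the diff): how an
// epilogue might drive the accumulator iterator defined above. The tile sizes,
// instruction shape and the store_accumulators helper are assumptions standing in
// for types produced by the surrounding warp-level GEMM; alpha/beta scaling and the
// beta != 0 read-modify-write path happen inside store().
using ExampleOutputIterator = cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
    cutlass::MatrixShape<64, 64>,               // warp tile (illustrative)
    int8_t,                                     // element written to global memory
    cutlass::layout::TensorNCxHWx<32>,          // interleaved output layout
    cutlass::gemm::GemmShape<8, 8, 16>,         // instruction shape (illustrative)
    cutlass::MatrixShape<1, 1>>;                // op delta (illustrative)

__device__ void store_accumulators(
    typename ExampleOutputIterator::TensorRef output_ref,          // NCxHWx tensor in global memory
    typename ExampleOutputIterator::TensorCoord extent,            // N, H, W, C extent used for bounds checks
    typename ExampleOutputIterator::Fragment const &accumulators,  // int32 accumulators held by this thread
    int lane_id) {

  ExampleOutputIterator iterator(output_ref, lane_id, extent, /*alpha=*/1.0f, /*beta=*/0.0f);
  iterator.store(accumulators);
}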

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

@ -2243,6 +2243,847 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator specialized for 'TN' arrangement
|
||||
template <
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Operand identity
|
||||
Operand Operand_,
|
||||
/// Data type of A elements
|
||||
typename Element_,
|
||||
/// Layout of matrix operand
|
||||
typename Layout_,
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
typename InstructionShape_,
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
||||
/// MatrixShape)
|
||||
int OpDelta_,
|
||||
/// Number of threads participating in one matrix operation
|
||||
int Threads = 32,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_ = 1>
|
||||
class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner {
|
||||
public:
|
||||
|
||||
/// Shape of tile to load (concept: MatrixShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Operand tag
|
||||
static Operand const kOperand = Operand_;
|
||||
|
||||
/// Basic check
|
||||
static_assert(kOperand == Operand::kA || kOperand== Operand::kB,
|
||||
"MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
/// Layout of source tile
|
||||
using Layout = Layout_;
|
||||
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
using InstructionShape = InstructionShape_;
|
||||
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
|
||||
static int const kOpDelta = OpDelta_;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = 32;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
|
||||
/// Index type
|
||||
using Index = typename TensorRef::Index;
|
||||
|
||||
/// Long Index type
|
||||
using LongIndex = typename TensorRef::LongIndex;
|
||||
|
||||
/// Coordinate for an element in the tensor
|
||||
using TensorCoord = typename TensorRef::TensorCoord;
|
||||
|
||||
/// Number of elements accessed per Shared Memory load
|
||||
static int const kElementsPerAccess = 4;
|
||||
|
||||
private:
|
||||
|
||||
static int const kInterleavedTileRows = 32;
|
||||
static int const kInterleavedTileColumns = 32;
|
||||
static int const kInstructionsPerTile = 2;
|
||||
|
||||
/// Rounded up instruction counts
|
||||
using TileCount = MatrixShape<
|
||||
Shape::kRow / kInterleavedTileRows,
|
||||
Shape::kColumn / kInterleavedTileColumns
|
||||
>;
|
||||
|
||||
using FragmentCount = MatrixShape<
|
||||
TileCount::kRow * kInstructionsPerTile,
|
||||
TileCount::kColumn * kInstructionsPerTile
|
||||
>;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Fragment object holding a thread's part of a tile
|
||||
using Fragment = Array<
|
||||
Element,
|
||||
(kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess
|
||||
>;
|
||||
|
||||
/// Memory access type
|
||||
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
||||
|
||||
private:
|
||||
|
||||
/// Underlying tensor reference
|
||||
TensorRef ref_;
|
||||
|
||||
/// Extent of tensor
|
||||
MatrixCoord extent_;
|
||||
|
||||
/// Origin
|
||||
MatrixCoord origin_;
|
||||
|
||||
/// Used to conditionally enable extents checking
|
||||
bool divisible_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor constructs null iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner(): divisible_(true) { }
|
||||
|
||||
/// Constructor from TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner(
|
||||
TensorRef const &ref,
|
||||
int lane_id
|
||||
):
|
||||
ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) {
|
||||
|
||||
int quad_id = lane_id / 4;
|
||||
int lane_in_quad = (lane_id % 4);
|
||||
|
||||
if (kOperand == Operand::kA) {
|
||||
|
||||
int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad;
|
||||
int col_idx = 0;
|
||||
|
||||
origin_ = MatrixCoord(row_idx, col_idx);
|
||||
}
|
||||
else {
|
||||
|
||||
int row_idx = 0;
|
||||
int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad;
|
||||
|
||||
origin_ = MatrixCoord(row_idx, col_idx);
|
||||
}
|
||||
|
||||
ref_.add_coord_offset(origin_);
|
||||
}
|
||||
|
||||
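  // Worked example of the lane mapping above (illustrative numbers only): for
  // lane_id = 22, quad_id = 22 / 4 = 5 and lane_in_quad = 22 % 4 = 2. With
  // kInstructionsPerTile = 2:
  //   - Operand A: row_idx = ((5 & 1) + ((5 & 4) / 2)) * 8 + 2 = (1 + 2) * 8 + 2 = 26,
  //     col_idx = 0, so this lane's origin is (26, 0).
  //   - Operand B: row_idx = 0, col_idx = (5 / 2) * 8 + 2 = 18, so the origin is (0, 18).
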
  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner(
    TensorRef const &ref,
    TensorCoord extent,
    int lane_id
  ): ref_(ref), extent_(extent), divisible_(false) {

    int quad_id = lane_id / 4;
    int lane_in_quad = (lane_id % 4);

    if (kOperand == Operand::kA) {

      int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad;
      int col_idx = 0;

      origin_ = MatrixCoord(row_idx, col_idx);
    }
    else {

      int row_idx = 0;
      int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad;

      origin_ = MatrixCoord(row_idx, col_idx);
    }

#if defined(__CUDA_ARCH__)
    __syncthreads();
#endif

    ref_.add_coord_offset(origin_);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_pointer_offset(LongIndex offset) {

    ref_.add_pointer_offset(offset);

    return *this;
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_tile_offset(TensorCoord const &tile_offset) {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
    origin_ += coord_offset;

    ref_.add_coord_offset(coord_offset);

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator++() {

    if (kOperand == Operand::kA) {
      add_tile_offset({0, 1});
    }
    else {
      add_tile_offset({1, 0});
    }

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator--() {

    if (kOperand == Operand::kA) {
      add_tile_offset({0, -1});
    }
    else {
      add_tile_offset({-1, 0});
    }

    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {

    load_with_pointer_offset(frag, 0);
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_pointer_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a linear offset
    Index pointer_offset) const {

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
    AccessType const *access_ptr = reinterpret_cast<AccessType const *>(ref_.data());
    int ldm = ref_.stride()[0];

    if (kOperand == Operand::kA) {

      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < FragmentCount::kRow; ++idx) {

        int tile_idx = idx / 2;
        int quad_idx = idx % 2;

        int row_offset = tile_idx * kInterleavedTileRows + quad_idx * 4;
        frag_ptr[idx] = access_ptr[row_offset * ldm / kElementsPerAccess];
      }
    }
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < FragmentCount::kColumn; ++idx) {

        int tile_idx = idx / 2;
        int quad_idx = idx % 2;

        int col_offset = tile_idx * kInterleavedTileColumns + quad_idx * 4;
        frag_ptr[idx] = access_ptr[col_offset * ldm / kElementsPerAccess];
      }
    }
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_byte_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a linear offset
    Index byte_offset) const {

    load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits<Element>::value);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset) const {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);

    load_with_pointer_offset(frag, ref_.offset(coord_offset));
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// loads a tile with a logical offset AND a pointer offset
    Index pointer_offset) const {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);

    load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load_with_byte_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// loads a tile with a logical offset AND a pointer offset
    Index byte_offset) const {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);

    load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits<Element>::value);
  }

  /// Notify the iterator which k-group it is currently pointing to.
  ///
  /// This does not advance the iterator. Rather, it overrides its internal
  /// tracking with constant-valued k-group index to enable the compiler to
  /// fold constants and achieve more efficient code.
  ///
  /// This is used by some nontrivial permuted layouts.
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation
  }
};

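// Usage sketch (illustrative only): the canonical iterators in this header follow the
// usual warp-level tile iterator interface. The tile shape, element type, and
// instruction shape below are assumptions chosen for the sketch, not requirements of
// this header.
//
//   using Iterator = cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner<
//       cutlass::MatrixShape<32, 32>,          // Shape of the tile each warp loads (assumed)
//       cutlass::gemm::Operand::kA,            // Operand identity
//       cutlass::half_t,                       // Element type (assumed)
//       cutlass::layout::RowMajor,             // Layout of the operand (assumed)
//       cutlass::gemm::GemmShape<16, 16, 4>,   // Instruction shape (assumed)
//       1>;                                    // OpDelta
//
//   // In device code, each lane constructs the iterator from a TensorRef and its lane id
//   // (ref_A and lane_id as provided by the enclosing warp-level mainloop), loads its
//   // Fragment, and advances to the next tile (the K dimension for operand A) with ++:
//   Iterator iter(ref_A, lane_id);
//   typename Iterator::Fragment frag;
//   iter.load(frag);
//   ++iter;
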
/// Tile iterator specialized for 'NT' arrangement
template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand_,
    /// Data type of A elements
    typename Element_,
    /// Layout of matrix operand
    typename Layout_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Delta between *MMA operations (in units of *MMA operations, concept:
    /// MatrixShape)
    int OpDelta_,
    /// Number of threads participating in one matrix operation
    int Threads = 32,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1>
class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand_;

  /// Basic check
  static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
    "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma.");

  /// Element type
  using Element = Element_;

  /// Layout of source tile
  using Layout = Layout_;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;

  /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape)
  static int const kOpDelta = OpDelta_;

  /// Number of participating threads
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Number of elements accessed per Shared Memory load
  static int const kElementsPerAccess = 4;

private:

  static int const kInterleavedTileRows = 32;
  static int const kInterleavedTileColumns = 32;
  static int const kInstructionsPerTile = 2;

  /// Rounded up instruction counts
  using TileCount = MatrixShape<
    Shape::kRow / kInterleavedTileRows,
    Shape::kColumn / kInterleavedTileColumns
  >;

  using FragmentCount = MatrixShape<
    TileCount::kRow * kInstructionsPerTile,
    TileCount::kColumn * kInstructionsPerTile
  >;

public:

  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<
    Element,
    (kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess
  >;

  /// Memory access type
  using AccessType = AlignedArray<Element, kElementsPerAccess>;

private:

  /// Underlying tensor reference
  TensorRef ref_;

  /// Extent of tensor
  MatrixCoord extent_;

  /// Origin
  MatrixCoord origin_;

  /// Used to conditionally enable extents checking
  bool divisible_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter(): divisible_(true) { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter(
    TensorRef const &ref,
    int lane_id
  ):
    ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) {

    int quad_id = lane_id / 4;
    int lane_in_quad = (lane_id % 4);

    if (kOperand == Operand::kA) {

      int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile;
      int col_idx = lane_in_quad;

      origin_ = MatrixCoord(row_idx, col_idx);
    }
    else {

      int row_idx = lane_in_quad;
      int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile;

      origin_ = MatrixCoord(row_idx, col_idx);
    }

    ref_.add_coord_offset(origin_);
  }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter(
    TensorRef const &ref,
    TensorCoord extent,
    int lane_id
  ): ref_(ref), extent_(extent), divisible_(false) {

    int quad_id = lane_id / 4;
    int lane_in_quad = (lane_id % 4);

    if (kOperand == Operand::kA) {

      int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile;
      int col_idx = lane_in_quad;

      origin_ = MatrixCoord(row_idx, col_idx);
    }
    else {

      int row_idx = lane_in_quad;
      int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile;

      origin_ = MatrixCoord(row_idx, col_idx);
    }

#if defined(__CUDA_ARCH__)
    __syncthreads();
#endif

    ref_.add_coord_offset(origin_);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_pointer_offset(LongIndex offset) {

    ref_.add_pointer_offset(offset);

    return *this;
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_tile_offset(TensorCoord const &tile_offset) {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
    origin_ += coord_offset;

    ref_.add_coord_offset(coord_offset);

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator++() {

    if (kOperand == Operand::kA) {
      add_tile_offset({0, 1});
    }
    else {
      add_tile_offset({1, 0});
    }

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator--() {

    if (kOperand == Operand::kA) {
      add_tile_offset({0, -1});
    }
    else {
      add_tile_offset({-1, 0});
    }

    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {

    load_with_pointer_offset(frag, 0);
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_pointer_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a linear offset
    Index pointer_offset) const {

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
    AccessType const *access_ptr = reinterpret_cast<AccessType const *>(ref_.data());
    int ldm = ref_.stride()[0];

    if (kOperand == Operand::kA) {

      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < FragmentCount::kRow; ++idx) {

        int tile_idx = idx / 2;
        int quad_idx = idx % 2;

        int row_offset = tile_idx * kInterleavedTileRows;
        frag_ptr[idx] = access_ptr[row_offset / kElementsPerAccess + quad_idx];
      }
    }
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < FragmentCount::kColumn; ++idx) {

        int tile_idx = idx / 2;
        int quad_idx = idx % 2;

        int col_offset = tile_idx * kInterleavedTileColumns;
        frag_ptr[idx] = access_ptr[col_offset / kElementsPerAccess + quad_idx];
      }
    }
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_byte_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a linear offset
    Index byte_offset) const {

    load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits<Element>::value);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset) const {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);

    load_with_pointer_offset(frag, ref_.offset(coord_offset));
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// loads a tile with a logical offset AND a pointer offset
    Index pointer_offset) const {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);

    load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load_with_byte_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// loads a tile with a logical offset AND a pointer offset
    Index byte_offset) const {

    TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);

    load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits<Element>::value);
  }

  /// Notify the iterator which k-group it is currently pointing to.
  ///
  /// This does not advance the iterator. Rather, it overrides its internal
  /// tracking with constant-valued k-group index to enable the compiler to
  /// fold constants and achieve more efficient code.
  ///
  /// This is used by some nontrivial permuted layouts.
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Data type of elements
    typename Element_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions)
    int OpDelta_>
class MmaVoltaTensorOpMultiplicandTileIterator<
    Shape_,
    Operand::kA,
    Element_,
    cutlass::layout::RowMajor,
    InstructionShape_,
    OpDelta_,
    32
> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner<
    Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> {

public:
  using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner<
    Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_>;

  using TensorRef = typename Base::TensorRef;

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): Base(ref, lane_id) { }

};

template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Data type of elements
    typename Element_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions)
    int OpDelta_>
class MmaVoltaTensorOpMultiplicandTileIterator<
    Shape_,
    Operand::kA,
    Element_,
    cutlass::layout::ColumnMajor,
    InstructionShape_,
    OpDelta_,
    32
> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter<
    Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> {

public:
  using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter<
    Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_>;

  using TensorRef = typename Base::TensorRef;

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): Base(ref, lane_id) { }

};

template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Data type of elements
    typename Element_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions)
    int OpDelta_>
class MmaVoltaTensorOpMultiplicandTileIterator<
    Shape_, Operand::kB, Element_,
    cutlass::layout::ColumnMajor,
    InstructionShape_, OpDelta_, 32
> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner<
    Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> {

public:
  using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner<
    Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_>;

  using TensorRef = typename Base::TensorRef;

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): Base(ref, lane_id) { }
};

template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Data type of elements
    typename Element_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions)
    int OpDelta_>
class MmaVoltaTensorOpMultiplicandTileIterator<
    Shape_, Operand::kB, Element_,
    cutlass::layout::RowMajor,
    InstructionShape_, OpDelta_, 32
> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter<
    Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> {

public:
  using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter<
    Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_>;

  using TensorRef = typename Base::TensorRef;

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaVoltaTensorOpMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): Base(ref, lane_id) { }
};

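// Summary of the partial specializations above (as defined in this header):
//
//   Operand::kA with layout::RowMajor     -> MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner
//   Operand::kA with layout::ColumnMajor  -> MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter
//   Operand::kB with layout::ColumnMajor  -> MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner
//   Operand::kB with layout::RowMajor     -> MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter
//
// Illustrative instantiation (a sketch only; the tile shape, element type, and
// instruction shape are assumptions, and the trailing 32 matches the thread count
// used by the specializations above):
//
//   using WarpIteratorA = cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
//       cutlass::MatrixShape<32, 32>, cutlass::gemm::Operand::kA, cutlass::half_t,
//       cutlass::layout::RowMajor, cutlass::gemm::GemmShape<16, 16, 4>, 1, 32>;
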
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

@ -40,6 +40,7 @@
#endif
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/coord.h"
#include "cutlass/tensor_coord.h"
@ -121,6 +122,12 @@ public:
      LongIndex(stride_[2] * coord.n());
  }

  /// Returns the offset of a pitchlinear coordinate in linear memory.
  CUTLASS_HOST_DEVICE
  LongIndex operator()(PitchLinearCoord coord) const {
    return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
  }

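  // Sketch of the arithmetic above (assuming densely packed NHWC strides, so that
  // stride_[2] equals H * W * C): the pitch-linear view appears to treat the batch
  // index as the strided dimension and the flattened (h, w, c) index as the
  // contiguous dimension. For example, with stride_[2] = 1024 and
  // coord = (contiguous = 70, strided = 3), the returned offset is 70 + 3 * 1024 = 3142.
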
  /// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory.
  CUTLASS_HOST_DEVICE
  TensorCoord inverse(LongIndex index) const {
@ -182,7 +189,6 @@ public:
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Mapping function for 4-D NCHW tensors.
@ -424,6 +430,14 @@ public:
      LongIndex(stride_[2] * c_major);
  }

  /// Returns the offset of a pitchlinear coordinate in linear memory.
  CUTLASS_HOST_DEVICE
  LongIndex operator()(PitchLinearCoord const &coord) const {
    return (coord.contiguous() % kInterleave) +
      LongIndex((coord.contiguous() / kInterleave) * stride_[2]) +
      LongIndex(coord.strided() * kInterleave);
  }

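  // Worked example of the interleaved mapping above (illustrative numbers only): with
  // kInterleave = 32, stride_[2] = 1024, and coord = (contiguous = 70, strided = 3),
  // the offset is (70 % 32) + (70 / 32) * 1024 + 3 * 32 = 6 + 2048 + 96 = 2150.
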
  /// Returns the stride of the layout
  CUTLASS_HOST_DEVICE
  Stride stride() const {
@ -340,6 +340,134 @@ struct PitchLinearWarpRakedThreadMap {

////////////////////////////////////////////////////////////////////////////////

/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous
/// elements. Warps are arranged based on a stride.
///
/// This ThreadMap is used by tensor core kernels for NCxHWx layout.
template <
  typename Shape_,
  int Threads,
  typename WarpThreadArrangement_,
  int ElementsPerAccess = 1
>
struct PitchLinearStridedWarpRakedThreadMap {

  /// Tensor coordinate
  using TensorCoord = layout::PitchLinearCoord;

  /// Tile shape
  using Shape = Shape_;

  /// Number of threads total
  static int const kThreads = Threads;

  using WarpThreadArrangement = WarpThreadArrangement_;

  /// Extract vector length from Layout
  static int const kElementsPerAccess = ElementsPerAccess;

  /// Base ThreadMap
  using BaseThreadMap = PitchLinearWarpRakedThreadMap<
    Shape,
    kThreads,
    WarpThreadArrangement,
    kElementsPerAccess
  >;

  /// Shape of access by each thread
  using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape;

  struct Detail {

    using WarpThreadArrangement = WarpThreadArrangement_;

    using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations;

    static int const kWarpSize = BaseThreadMap::Detail::kWarpSize;

    static int const kWarpCount = BaseThreadMap::Detail::kWarpCount;

    using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses;

    // Divide the accesses among the warps, partitioning the contiguous dimension
    // first and the strided dimension second.
    static int const kWarpsContiguous =
        (WarpAccessIterations::kContiguous >= kWarpCount
             ? kWarpCount
             : WarpAccessIterations::kContiguous);

    static int const kWarpsStrided =
        (kWarpCount > WarpAccessIterations::kContiguous
             ? kWarpCount / kWarpsContiguous
             : 1);

    /// Arrangement of warps within a threadblock-scoped tile
    using WarpArrangement = layout::PitchLinearShape<
      kWarpsContiguous, kWarpsStrided
    >;

  };

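  // Worked example of the warp partitioning above (illustrative numbers only):
  // suppose Detail::WarpAccessIterations = (2, 8) and Detail::kWarpCount = 8. Then
  // kWarpsContiguous = 2 (the contiguous dimension is exhausted first) and
  // kWarpsStrided = 8 / 2 = 4, so the warp arrangement is 2 x 4 and the per-warp
  // Iterations defined below evaluate to (2 / 2, 8 / 4) = (1, 2).
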
  ///< Iterations along each dimension (concept: PitchLinearShape)
  using Iterations = layout::PitchLinearShape<
    Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
    Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
  >;

  static_assert(Iterations::kCount,
                "Number of iterations must be non-zero");

  ///< Delta between accesses (units of elements, concept: PitchLinearShape)
  using Delta = typename BaseThreadMap::Delta;

  /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
  CUTLASS_HOST_DEVICE
  static TensorCoord initial_offset(int thread_id) {

    int warp_id = (thread_id / Detail::kWarpSize);
    int lane_id = (thread_id % Detail::kWarpSize);

    //
    // compute warp-level offset
    //

    // This is the shape of the entire area covered by a warp's memory access (in units of vectors)
    layout::PitchLinearCoord warp_footprint{
      Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
      Detail::WarpThreadArrangement::kStrided * Iterations::kStrided
    };

    // This is the offset of a specific warp (in units of vectors)
    layout::PitchLinearCoord warp_offset{
      (warp_id % Detail::kWarpsContiguous),
      (warp_id / Detail::kWarpsContiguous)
    };

    // This is the offset of a specific thread within a warp (units of vectors)
    layout::PitchLinearCoord thread_offset_in_warp{
      lane_id % Detail::WarpThreadArrangement::kContiguous,
      lane_id / Detail::WarpThreadArrangement::kContiguous
    };

    // This is the offset of a thread within a threadblock tile (units of vectors)
    layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
      warp_footprint * warp_offset + thread_offset_in_warp;

    // This is the offset of a thread within a threadblock tile (units of elements)
    layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
      thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
      thread_offset_in_threadblock_tile_vec.strided()
    };

    return thread_offset_in_threadblock_tile_base;
  }

};

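// Worked example of initial_offset (illustrative numbers only): assume the usual warp
// size of 32, WarpThreadArrangement = (8, 4), kElementsPerAccess = 8, and the 2 x 4
// warp arrangement with Iterations = (1, 2) from the example above, so the warp
// footprint is (8 * 1, 4 * 2) = (8, 8) vectors. For thread_id = 45: warp_id = 1 and
// lane_id = 13, giving warp_offset = (1 % 2, 1 / 2) = (1, 0) and
// thread_offset_in_warp = (13 % 8, 13 / 8) = (5, 1). The vector-unit offset is then
// (8, 8) * (1, 0) + (5, 1) = (13, 1), which scales to (13 * 8, 1) = (104, 1) elements.
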
////////////////////////////////////////////////////////////////////////////////

/// Transpose the existing ThreadMap. For example, interleaved layout is like
/// congruous in the global memory and crosswise in the shared memory. We need
/// to transpose the coordinates between the two.

@ -500,7 +500,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileAccessIterator for pitch-linear data.
/// Specialization of PredicatedTileAccessIterator for column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
@ -676,7 +676,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileAccessIterator for pitch-linear data.
/// Specialization of PredicatedTileAccessIterator for row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
@ -852,8 +852,8 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileAccessIterator for interleaved data. It
/// is mapped to the congruous layout.
/// Specialization of PredicatedTileAccessIterator for column-major interleaved data.
/// It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
@ -1032,8 +1032,8 @@ class PredicatedTileAccessIterator<Shape_, Element_,

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileAccessIterator for interleaved data. It
/// is mapped to the congruous layout.
/// Specialization of PredicatedTileAccessIterator for row-major interleaved data.
/// It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |