Compare commits
79 Commits
thakkarV-p
...
v3.2.2
| Author | SHA1 | Date | |
|---|---|---|---|
| 44c704eae8 | |||
| 6581237a48 | |||
| 5cd735c48e | |||
| 67ae8e0603 | |||
| 14f69bddc8 | |||
| 90d3b0fb18 | |||
| e0aaa3c3b3 | |||
| 8783c41851 | |||
| 6407bcdf0a | |||
| a77b2c9cb8 | |||
| 34bbadd3ff | |||
| 88c0d7c726 | |||
| e01b9b5029 | |||
| 34fd98056b | |||
| 3a8f57a3c8 | |||
| 6673df0e48 | |||
| 7618e9bfd8 | |||
| a88c41cf8d | |||
| 27de343535 | |||
| 2a9fa23e06 | |||
| 2e56cfabee | |||
| 3930f709ce | |||
| 7e5ee8b7bf | |||
| 2d9a557427 | |||
| 4575443d44 | |||
| a0d787b746 | |||
| d20f3a9542 | |||
| 8e85580859 | |||
| 146d314057 | |||
| f679663224 | |||
| e066ced33b | |||
| 9b923dd4c4 | |||
| f6d42f2dd0 | |||
| 473a67073e | |||
| 87349d3496 | |||
| fde824af21 | |||
| 7dbf423763 | |||
| 6f47420213 | |||
| 4638250469 | |||
| 7859fe322a | |||
| d3e72719b4 | |||
| b4ab501767 | |||
| f079619f5e | |||
| 13f413493a | |||
| 6fbc0d3380 | |||
| b97404837e | |||
| e2953d47c5 | |||
| 19c4a4815e | |||
| fcfbd23e26 | |||
| b250faccd3 | |||
| 24c8b7d8a2 | |||
| 7c04f95415 | |||
| 6f8596ce3f | |||
| fe2f491dd7 | |||
| df02482f1d | |||
| 180c5629bf | |||
| e36912f961 | |||
| 9a83bd3381 | |||
| 54bebe417d | |||
| 43cfbe0086 | |||
| 4a68cf748e | |||
| d572cc1aab | |||
| 9b8166e3f0 | |||
| e2d439ee7e | |||
| 0435979f59 | |||
| 2ba1ef10be | |||
| 0964bdb64c | |||
| ecbd24566c | |||
| 660a05f581 | |||
| bc36122c3f | |||
| 15d9d31f1f | |||
| 1eef5c3cf1 | |||
| 87070b6d51 | |||
| 77549ae6c8 | |||
| 42290f5d1c | |||
| 209faf7b94 | |||
| 6116706c96 | |||
| 2670b973dd | |||
| af332d4aa9 |
49
CHANGELOG.md
49
CHANGELOG.md
@ -1,5 +1,52 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25)
|
||||
* Fixes illegal memory access issue [1138](https://github.com/NVIDIA/cutlass/issues/1138) hit by FlashAttention tests in PyTorch.
|
||||
|
||||
## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
|
||||
* Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
|
||||
* SM80 EVT support in C++ and Python.
|
||||
* Other SM90 epilogue improvements.
|
||||
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
|
||||
* SM90 TF32 kernel improvements for all layouts.
|
||||
* SM90 rasterization direction support in the CUTLASS profiler.
|
||||
* Improvement for CUTLASS profiler build times.
|
||||
* Remove Python-C++ bindings.
|
||||
|
||||
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
|
||||
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
|
||||
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
|
||||
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
|
||||
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix.
|
||||
* [Permute + GEMM fusion](examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
|
||||
* [Row Broadcast](include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* The GitHub branch is renamed from `master` to `main` in this release.
|
||||
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
|
||||
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
@ -57,7 +104,7 @@
|
||||
* [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [Python-based instance emitter](/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
|
||||
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
|
||||
|
||||
241
CMakeLists.txt
241
CMakeLists.txt
@ -26,7 +26,8 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
|
||||
cmake_minimum_required(VERSION 3.19 FATAL_ERROR)
|
||||
cmake_policy(SET CMP0112 NEW)
|
||||
|
||||
if(cutlass_LOADED)
|
||||
# If CUTLASS has been previously fetched and loaded, don't do it again.
|
||||
@ -39,7 +40,7 @@ endif()
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
|
||||
|
||||
project(CUTLASS VERSION 3.0.0 LANGUAGES CXX)
|
||||
project(CUTLASS VERSION 3.2.2 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 11.3)
|
||||
@ -58,6 +59,8 @@ endif()
|
||||
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
################################################################################
|
||||
|
||||
#
|
||||
# CUTLASS 3.x requires C++17
|
||||
#
|
||||
@ -79,16 +82,41 @@ endif()
|
||||
|
||||
message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")
|
||||
|
||||
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
# 0 - Sanity, 1 - Release-Quality, 2 - Exhaustive
|
||||
|
||||
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
# Install cutlass_library Python package
|
||||
execute_process(
|
||||
WORKING_DIRECTORY ${CUTLASS_DIR}/python
|
||||
COMMAND ${Python3_EXECUTABLE} ${CUTLASS_DIR}/python/setup_library.py develop --user
|
||||
RESULT_VARIABLE cutlass_lib_GENERATOR_INSTALL_RESULT
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
|
||||
)
|
||||
|
||||
if(NOT cutlass_lib_GENERATOR_INSTALL_RESULT EQUAL 0)
|
||||
message(FATAL_ERROR "Error installing cutlass_library package. See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log")
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
|
||||
|
||||
if(CUTLASS_ENABLE_HEADERS_ONLY)
|
||||
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT ON)
|
||||
set(CUTLASS_ENABLE_LIBRARY_INIT OFF)
|
||||
set(CUTLASS_ENABLE_TESTS_INIT OFF)
|
||||
else()
|
||||
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT ON)
|
||||
set(CUTLASS_ENABLE_LIBRARY_INIT ON)
|
||||
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ON)
|
||||
else()
|
||||
set(CUTLASS_ENABLE_TESTS_INIT OFF)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.")
|
||||
@ -97,19 +125,11 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C
|
||||
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
|
||||
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
|
||||
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
|
||||
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Proformance")
|
||||
|
||||
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}})
|
||||
else()
|
||||
set(CUTLASS_ENABLE_TESTS_INIT OFF)
|
||||
endif()
|
||||
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance")
|
||||
|
||||
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
|
||||
|
||||
if (CUTLASS_ENABLE_TESTS)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
|
||||
endif()
|
||||
set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests")
|
||||
################################################################################
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
@ -124,6 +144,17 @@ endif()
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
|
||||
# Find unsupported and deprecated compute capabilities
|
||||
if (CUTLASS_NVCC_ARCHS_SUPPORTED)
|
||||
set(CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS})
|
||||
list(REMOVE_ITEM CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS_SUPPORTED})
|
||||
if (CUTLASS_NVCC_ARCHS_UNSUPPORTED)
|
||||
message(WARNING "Using unsupported or deprecated compute capabilities ${CUTLASS_NVCC_ARCHS_UNSUPPORTED}. Support may be removed in future versions.")
|
||||
endif()
|
||||
else()
|
||||
message(WARNING "No supported compute capabilities for CUDA ${CUDA_VERSION}.")
|
||||
endif()
|
||||
|
||||
# Special policy introduced in CMake 3.13
|
||||
if (POLICY CMP0076)
|
||||
cmake_policy(SET CMP0076 NEW)
|
||||
@ -161,8 +192,8 @@ if(WIN32)
|
||||
endif()
|
||||
|
||||
if (WIN32)
|
||||
# Enable more warnings and treat as errors
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)
|
||||
# Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors.
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3)
|
||||
|
||||
# Disable warning on Unicode characters
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)
|
||||
@ -185,15 +216,16 @@ set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
|
||||
|
||||
################################################################################
|
||||
#
|
||||
# CUTLASS generator cmake configuration
|
||||
#
|
||||
|
||||
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
|
||||
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
|
||||
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
|
||||
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
|
||||
|
||||
# Test Levels L0, L1, L2
|
||||
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
################################################################################
|
||||
|
||||
set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
|
||||
|
||||
@ -213,6 +245,8 @@ if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
#
|
||||
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
|
||||
#
|
||||
@ -262,6 +296,8 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
|
||||
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
@ -287,9 +323,10 @@ if (CUTLASS_ENABLE_OPENMP_TESTS)
|
||||
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
|
||||
if(UNIX)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
|
||||
endif()
|
||||
|
||||
# Don't leak lineinfo in release builds
|
||||
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
|
||||
@ -352,6 +389,28 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
|
||||
cmake_policy(SET CMP0104 NEW)
|
||||
endif()
|
||||
|
||||
if (MSVC)
|
||||
|
||||
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
|
||||
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
|
||||
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
|
||||
# between GCC/Clang and MSVC.
|
||||
#
|
||||
# error : a constexpr function cannot have a nonliteral return type "dim3"
|
||||
#
|
||||
# See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
|
||||
|
||||
endif()
|
||||
|
||||
# Some tests require this build option in order to link.
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj")
|
||||
endif()
|
||||
|
||||
function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
@ -466,7 +525,8 @@ endfunction()
|
||||
|
||||
# GLOB for CUTLASS header files. Should we use a static list instead?
|
||||
file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h)
|
||||
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h)
|
||||
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h include/cutlass/*.hpp include/cutlass/*.inl)
|
||||
file(GLOB_RECURSE CUTLASS_CUTE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cute/*.h*)
|
||||
file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h)
|
||||
|
||||
###################################################################################################
|
||||
@ -526,11 +586,17 @@ target_include_directories(
|
||||
$<INSTALL_INTERFACE:include>
|
||||
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
|
||||
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
|
||||
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/examples>
|
||||
)
|
||||
|
||||
# Mark CTK headers as system to supress warnings from them
|
||||
target_include_directories(
|
||||
CUTLASS
|
||||
SYSTEM INTERFACE
|
||||
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY
|
||||
${CUTLASS_INCLUDE_DIR}/
|
||||
@ -587,6 +653,11 @@ endif()
|
||||
|
||||
include(CTest)
|
||||
enable_testing()
|
||||
|
||||
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
|
||||
endif()
|
||||
|
||||
if (NOT TARGET test_all)
|
||||
add_custom_target(test_all)
|
||||
endif()
|
||||
@ -623,7 +694,7 @@ endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake)
|
||||
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
|
||||
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
|
||||
|
||||
function(cutlass_add_executable_tests NAME TARGET)
|
||||
@ -637,14 +708,16 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
# DEPENDS: A list of targets or files on which this test is dependent.
|
||||
# DEPENDEES: A list of targets which should depend on this test.
|
||||
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
|
||||
# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
|
||||
# to pass to the test executable. A unique test is generated for each set of
|
||||
# options given. If this option is not used, a single test with no arguments is generated.
|
||||
# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for
|
||||
# generating the full variable name to be referenced.
|
||||
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
|
||||
# test results to speed up test runtime.
|
||||
#
|
||||
|
||||
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
|
||||
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE)
|
||||
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
|
||||
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
@ -652,6 +725,9 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
set(__DISABLE_TESTS OFF)
|
||||
endif()
|
||||
|
||||
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
|
||||
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
|
||||
|
||||
if (__RESULT_CACHE_FILE)
|
||||
|
||||
add_custom_command(
|
||||
@ -688,7 +764,6 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
endif()
|
||||
|
||||
list(LENGTH __TEST_COMMAND_OPTIONS CMD_COUNT)
|
||||
set(CMD_IDX 0)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
add_custom_target(${NAME} DEPENDS ${TARGET} ${__DEPENDS})
|
||||
@ -697,12 +772,22 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
foreach(CMD_OPTIONS ${__TEST_COMMAND_OPTIONS})
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
set(_INLINE_PER_TEST_CODE)
|
||||
|
||||
file(READ "${PROJECT_SOURCE_DIR}/cmake/CTestTestfile.test.configure.cmake" _INLINE_PER_TEST_CODE_TEMPLATE)
|
||||
|
||||
endif()
|
||||
|
||||
set(TEST_GROUP_NAME ${NAME})
|
||||
|
||||
foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
set(TEST_NAME ${NAME}_${CMD_IDX})
|
||||
string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TEST_NAME)
|
||||
else()
|
||||
set(TEST_NAME ${NAME})
|
||||
string(TOLOWER "${NAME}" TEST_NAME)
|
||||
endif()
|
||||
|
||||
# The following rigmarole is needed to deal with spaces and possible quotes in
|
||||
@ -711,14 +796,14 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
# preserves any quotes. Note, they have to be in this order for it to work for
|
||||
# all the use cases below.
|
||||
|
||||
set(CMD_OPTIONS ${${CMD_OPTIONS}})
|
||||
list(JOIN CMD_OPTIONS " " TEST_COMMAND_OPTIONS)
|
||||
separate_arguments(CMD_OPTIONS)
|
||||
|
||||
set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}})
|
||||
list(JOIN TEST_COMMAND_OPTIONS " " TEST_COMMAND_OPTIONS)
|
||||
separate_arguments(TEST_COMMAND_OPTIONS)
|
||||
|
||||
add_custom_target(
|
||||
${TEST_NAME}
|
||||
COMMAND
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${TEST_COMMAND_OPTIONS}
|
||||
DEPENDS
|
||||
${TARGET}
|
||||
)
|
||||
@ -731,41 +816,48 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
add_dependencies(${DEPENDEE} ${TEST_NAME})
|
||||
endforeach()
|
||||
|
||||
add_test(
|
||||
NAME c${TEST_NAME}
|
||||
COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
|
||||
)
|
||||
set(TEST_NAME c${TEST_NAME})
|
||||
string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY)
|
||||
string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}")
|
||||
|
||||
set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS})
|
||||
endforeach()
|
||||
|
||||
# To run the tests from an install package with tests enabled, we need to generate test files
|
||||
# that don't rely on the current directory structure in build.
|
||||
|
||||
set(TEST_NAME c${NAME})
|
||||
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
|
||||
file(MAKE_DIRECTORY ${TEST_GEN_DIR})
|
||||
|
||||
set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
|
||||
set(TEST_USE_EXTENDED_FORMAT ON)
|
||||
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
|
||||
|
||||
set(TEST_EXE_PATH $<TARGET_FILE_NAME:${TARGET}>)
|
||||
set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format.
|
||||
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY)
|
||||
|
||||
# The following line imports the tests for immediate run via `make test`.
|
||||
|
||||
include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake)
|
||||
|
||||
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
|
||||
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
# To run the tests from an install package with tests enabled, we need to generate test files
|
||||
# that don't rely on the current directory structure in build.
|
||||
file(GENERATE
|
||||
OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
|
||||
INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in"
|
||||
)
|
||||
|
||||
set(TEST_NAME c${TEST_NAME})
|
||||
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
|
||||
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
|
||||
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY)
|
||||
install(
|
||||
FILES "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
|
||||
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/${TEST_NAME}
|
||||
RENAME CTestTestfile.${TEST_NAME}.cmake
|
||||
)
|
||||
|
||||
file(GENERATE
|
||||
OUTPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
|
||||
INPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake"
|
||||
)
|
||||
|
||||
install(
|
||||
FILES "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
|
||||
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/
|
||||
)
|
||||
|
||||
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
|
||||
|
||||
endif()
|
||||
|
||||
math(EXPR CMD_IDX "${CMD_IDX} + 1")
|
||||
|
||||
endforeach()
|
||||
|
||||
endfunction()
|
||||
|
||||
if (CUTLASS_ENABLE_TOOLS)
|
||||
@ -774,6 +866,7 @@ if (CUTLASS_ENABLE_TOOLS)
|
||||
add_dependencies(test_all test_profiler)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
add_dependencies(test_all test_examples)
|
||||
@ -781,38 +874,27 @@ endif()
|
||||
|
||||
if (CUTLASS_ENABLE_TESTS)
|
||||
add_subdirectory(test)
|
||||
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
|
||||
add_dependencies(test_all test_unit)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake")
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")
|
||||
|
||||
file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n")
|
||||
file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n")
|
||||
foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
|
||||
file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
|
||||
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
|
||||
endforeach()
|
||||
|
||||
install(
|
||||
FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
|
||||
FILES "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake"
|
||||
DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
#? install(
|
||||
#? FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake
|
||||
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
|
||||
#? )
|
||||
#?
|
||||
#? install(
|
||||
#? DIRECTORY
|
||||
#? ${CMAKE_BINARY_DIR}/tools
|
||||
#? ${CMAKE_BINARY_DIR}/test
|
||||
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
|
||||
#? FILES_MATCHING PATTERN "CTestTestfile.cmake"
|
||||
#? )
|
||||
|
||||
################################################################################
|
||||
|
||||
include(CMakePackageConfigHelpers)
|
||||
@ -838,3 +920,4 @@ install(
|
||||
################################################################################
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake)
|
||||
|
||||
|
||||
15
CUDA.cmake
15
CUDA.cmake
@ -76,6 +76,7 @@ find_library(
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x86_64-linux-gnu
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
@ -120,6 +121,7 @@ find_library(
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x86_64-linux-gnu
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
@ -226,7 +228,14 @@ else()
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
|
||||
|
||||
if (MSVC)
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8)
|
||||
else()
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files")
|
||||
|
||||
function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
|
||||
@ -296,10 +305,10 @@ function(cutlass_add_library NAME)
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
endif()
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
|
||||
@ -2,12 +2,22 @@
|
||||
|
||||
## 2023
|
||||
|
||||
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
|
||||
|
||||
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
|
||||
|
||||
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
|
||||
|
||||
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
|
||||
|
||||
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
|
||||
|
||||
## 2022
|
||||
|
||||
- ["GPU Load Balancing"](https://arxiv.org/abs/2212.08964). Muhammad Osama. _Doctoral dissertation, University of California, Davis_, December 2022.
|
||||
|
||||
- ["Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"](https://arxiv.org/abs/2211.10017). Young Jin Kim, Rawn Henry, Raffy Fahim, Hany Hassan Awadalla. _Proceedings of the Third Workshop on Simple and Efficient Natural Language Processing_, December 2022.
|
||||
|
||||
- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.
|
||||
|
||||
- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.
|
||||
@ -18,7 +28,7 @@
|
||||
|
||||
- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.
|
||||
|
||||
- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
|
||||
- ["Real-time Neural Radiance Caching for Path Tracing"](https://dl.acm.org/doi/abs/10.1145/3450626.3459812). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
|
||||
|
||||
## 2020
|
||||
|
||||
|
||||
69
README.md
69
README.md
@ -1,8 +1,8 @@
|
||||

|
||||
|
||||
# CUTLASS 3.0
|
||||
# CUTLASS 3.2
|
||||
|
||||
_CUTLASS 3.0 - January 2023_
|
||||
_CUTLASS 3.2 - August 2023_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
|
||||
@ -31,33 +31,39 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
See the [functionality listing](/media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
|
||||
CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
|
||||
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
|
||||
# What's New in CUTLASS 3.0
|
||||
# What's New in CUTLASS 3.2
|
||||
|
||||
CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogenous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture.
|
||||
CUTLASS 3.2.0 is an update to CUTLASS adding:
|
||||
- New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm).
|
||||
- New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
- [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
- Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
- Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
- [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
- New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
|
||||
- Support for Windows (MSVC) builds.
|
||||
|
||||
- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md)
|
||||
- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md).
|
||||
- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe.
|
||||
- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe.
|
||||
- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs.
|
||||
- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA.
|
||||
- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality.
|
||||
- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour a metaprogramming based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu).
|
||||
- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format
|
||||
for kernel procedural names.
|
||||
- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release.
|
||||
CUTLASS 3.2.1 is an update to CUTLASS adding:
|
||||
- Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
|
||||
- SM80 EVT support in C++ and Python.
|
||||
- Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
|
||||
- Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
|
||||
- SM90 TF32 kernel improvements for all layouts.
|
||||
- SM90 rasterization direction support in the CUTLASS profiler.
|
||||
- Improvement for CUTLASS profiler build times.
|
||||
|
||||
## New architecture, compiler, and CUDA Toolkit requirements
|
||||
CUTLASS 3.2.2 is a minor update to CUTLASS adding:
|
||||
- Bug fix for illegal memory access issue hit by Flash Attention tests in PyTorch. See [1138](https://github.com/NVIDIA/cutlass/issues/1138) for details.
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
@ -65,7 +71,7 @@ Minimum requirements:
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS 3.0 *removes support* for the following:
|
||||
Starting from CUTLASS 3.0, CUTLASS removed support for the following:
|
||||
|
||||
- Maxwell and Pascal GPU architectures
|
||||
- Ubuntu 16.04
|
||||
@ -76,7 +82,7 @@ CUTLASS 3.0 *removes support* for the following:
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.0-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.1-gemm-peak-performance.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit peak performance comparable to cuBLAS for scalar GEMM
|
||||
@ -87,20 +93,21 @@ an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere
|
||||
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
|
||||
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
|
||||
|
||||
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
|
||||
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
|
||||
as shown in the above figure. Tensor Core operations are still implemented using CUDA's
|
||||
as shown in the above figure. Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8.
|
||||
performs best when built with the [**CUDA 12.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.
|
||||
|
||||
## Operating Systems
|
||||
We have tested the following environments.
|
||||
@ -110,8 +117,10 @@ We have tested the following environments.
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
| Ubuntu 22.04 | GCC 11.2.0 |
|
||||
| Windows 10.0 | Visual Studio 2019 v16.11.27 |
|
||||
|
||||
Note: We plan to add Windows (MSVC) & Clang compiler support soon.
|
||||
Note: We plan to add Clang compiler support soon.
|
||||
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
|
||||
|
||||
## Hardware
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
|
||||
@ -131,9 +140,9 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
|
||||
|
||||
## Target Architecture
|
||||
|
||||
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
@ -178,7 +187,8 @@ CUTLASS is a header-only template library and does not need to be built to be us
|
||||
projects. Client applications should target CUTLASS's `include/` directory in their include
|
||||
paths.
|
||||
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12.
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake.
|
||||
The minimum version of CMake is given in the [Quickstart guide](media/docs/quickstart.md).
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
@ -514,7 +524,7 @@ reference_device: Passed
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM conovlution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
|
||||
|
||||
@ -558,4 +568,3 @@ SPDX-License-Identifier: BSD-3-Clause
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
# Generated file
|
||||
|
||||
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
else()
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
|
||||
endif()
|
||||
|
||||
if (NOT "@TEST_EXE_DIR@" STREQUAL "")
|
||||
set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@)
|
||||
else()
|
||||
set(TEST_EXE_PATH @TEST_EXE@)
|
||||
endif()
|
||||
|
||||
add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
|
||||
if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
|
||||
set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
|
||||
endif()
|
||||
|
||||
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
|
||||
14
cmake/CTestTestfile.configure.cmake
Normal file
14
cmake/CTestTestfile.configure.cmake
Normal file
@ -0,0 +1,14 @@
|
||||
# Generated file
|
||||
|
||||
set(TEST_EXE_PATH @TEST_EXE_PATH@)
|
||||
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
|
||||
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)
|
||||
|
||||
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
else()
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
|
||||
endif()
|
||||
|
||||
@_INLINE_PER_TEST_CODE@
|
||||
|
||||
15
cmake/CTestTestfile.test.configure.cmake
Normal file
15
cmake/CTestTestfile.test.configure.cmake
Normal file
@ -0,0 +1,15 @@
|
||||
if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT)
|
||||
# The longform/extended format allows generator expressions to be
|
||||
# expanded property and is useful in contexts where the files need
|
||||
# to be immediately included into being-processed cmake code.
|
||||
add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
else()
|
||||
add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
endif()
|
||||
|
||||
if (TEST_EXE_WORKING_DIRECTORY)
|
||||
set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
|
||||
endif()
|
||||
|
||||
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
|
||||
|
||||
@ -2,6 +2,11 @@ get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
if(NOT TARGET nvidia::cutlass::CUTLASS)
|
||||
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
|
||||
if(TARGET nvidia::cutlass::CUTLASS)
|
||||
return()
|
||||
endif()
|
||||
|
||||
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
|
||||
|
||||
# For backward compatibility with the old name
|
||||
add_library(cutlass_lib ALIAS cutlass_library)
|
||||
|
||||
@ -9,7 +9,7 @@ endif()
|
||||
FetchContent_Declare(
|
||||
googletest
|
||||
GIT_REPOSITORY https://github.com/google/googletest.git
|
||||
GIT_TAG 0fe9660
|
||||
GIT_TAG v1.13.0
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(googletest)
|
||||
|
||||
@ -291,8 +291,8 @@ int run() {
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue>
|
||||
int32_t,
|
||||
int32_t>
|
||||
gemm_device;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
@ -355,4 +355,3 @@ int main() {
|
||||
|
||||
return run();
|
||||
}
|
||||
|
||||
|
||||
@ -143,7 +143,6 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = int32_t; // Data type of accumulator
|
||||
@ -555,6 +554,7 @@ Result profile_convolution(Options const &options) {
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
ElementOutput,
|
||||
cutlass::NumericConverterClamp<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
@ -674,7 +674,6 @@ Result profile_convolution(Options const &options) {
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
@ -761,11 +760,7 @@ int main(int argc, char const **args) {
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
|
||||
@ -27,7 +27,10 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
if (CUTLASS_ENABLE_LIBRARY)
|
||||
|
||||
# Planar Complex GEMM example
|
||||
cutlass_example_add_executable(
|
||||
@ -35,11 +38,6 @@ cutlass_example_add_executable(
|
||||
planar_complex.cu
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
|
||||
target_link_libraries(
|
||||
10_planar_complex
|
||||
PRIVATE
|
||||
@ -48,3 +46,4 @@ target_link_libraries(
|
||||
cuda
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@ -27,7 +27,10 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
if (CUTLASS_ENABLE_LIBRARY)
|
||||
|
||||
# Planar Complex Array GEMM example
|
||||
cutlass_example_add_executable(
|
||||
@ -35,11 +38,6 @@ cutlass_example_add_executable(
|
||||
planar_complex_array.cu
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
|
||||
target_link_libraries(
|
||||
11_planar_complex_array
|
||||
PRIVATE
|
||||
@ -48,3 +46,4 @@ target_link_libraries(
|
||||
cuda
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@ -64,6 +64,7 @@ endforeach()
|
||||
foreach(FUSION_GEMM_EXAMPLE
|
||||
fused_two_gemms_f16_sm75_rf
|
||||
fused_two_gemms_f16_sm75_shmem
|
||||
fused_two_gemms_grouped_f16_sm80_rf
|
||||
fused_two_gemms_f16_sm80_rf
|
||||
fused_two_gemms_f16_sm80_shmem
|
||||
fused_two_gemms_s8_sm75_rf
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
# Introduction
|
||||
|
||||
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
|
||||
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_fusion.png></p>
|
||||
|
||||
When running two unfused GEMM/Conv operations, each operation loads one input
|
||||
activation matrix, one weight matrix (or filter matrix) from the memory and then
|
||||
When running two unfused GEMM/Conv operations, each operation loads one input
|
||||
activation matrix, one weight matrix (or filter matrix) from the memory and then
|
||||
stores the result activation matrix back to the memory.
|
||||
|
||||
When the two GEMM/Conv operations are fused together, the mainloops of the two
|
||||
@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
|
||||
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
|
||||
2 GEMMs/Convs.
|
||||
|
||||
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
|
||||
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
|
||||
input activation, the example enforces the following two constraints:
|
||||
|
||||
- thread_block_tile_N = problem_N
|
||||
- thread_block_tile_N = problem_N
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>
|
||||
|
||||
@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
|
||||
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
|
||||
operation can be fully block-resident.
|
||||
|
||||
- warp_tile_N = thread_block_tile_N
|
||||
- warp_tile_N = thread_block_tile_N
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>
|
||||
|
||||
@ -82,7 +82,7 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
|
||||
|
||||
|
||||
|
||||
# Copyright
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
|
||||
//
|
||||
|
||||
B2bNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
@ -399,9 +400,9 @@ struct B2bFusedGemmRun
|
||||
//
|
||||
|
||||
B2bFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
@ -412,7 +413,7 @@ struct B2bFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -420,11 +421,11 @@ struct B2bFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@ -453,70 +454,90 @@ struct B2bFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
// batch_count is used as split-k when mode is kGemm according
|
||||
// to the GemmUniversal interface
|
||||
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A0 = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_C0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C1 = 0,
|
||||
int64_t batch_stride_D1 = 0,
|
||||
int64_t batch_stride_Bias0 = 0,
|
||||
int64_t batch_stride_Scale0 = 0,
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
||||
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
||||
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.n()});
|
||||
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
mode,
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
tensor_D1.device_ref(),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0,
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
batch_count,
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_A0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementAccumulator(0), //beta = 0
|
||||
reference_Z0.device_ref(),
|
||||
reference_Z0.device_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_C0
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutScaleBias
|
||||
> (
|
||||
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
|
||||
reference_D0.device_ref(),
|
||||
alpha0,
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref()
|
||||
tensor_Bias0.device_ref(),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_C0,
|
||||
batch_stride_Scale0,
|
||||
batch_stride_Bias0
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
alpha1, //intermediate alpha=1
|
||||
reference_D0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B1.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
beta1, //beta = 0
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
reference_D1.device_ref(),
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
@ -680,7 +729,7 @@ struct B2bFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -694,7 +743,7 @@ struct B2bFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
|
||||
450
examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h
Normal file
450
examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h
Normal file
@ -0,0 +1,450 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Containers for running grouped back-to-back GEMMs
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
if((val1) <= (val2)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
||||
#define CHECK_TRUE(val) \
|
||||
if(!(val)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename B2bGemm_>
|
||||
struct B2bFusedGroupedGemmRun
|
||||
{
|
||||
|
||||
using B2bGemm = B2bGemm_;
|
||||
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
||||
using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
B2bFusedGroupedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
|
||||
using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
|
||||
using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
|
||||
using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
|
||||
using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
|
||||
using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
|
||||
|
||||
int problem_count = (int)problem_sizes_0.size();
|
||||
|
||||
std::vector<HostTensorA> host_tensor_A0(problem_count);
|
||||
std::vector<HostTensorB> host_tensor_B0(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_C0(problem_count);
|
||||
std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
|
||||
std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
|
||||
std::vector<HostTensorB> host_tensor_B1(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_C1(problem_count);
|
||||
std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_D1(problem_count);
|
||||
std::vector<HostTensorZ> host_tensor_Z(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
|
||||
|
||||
std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
|
||||
std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
|
||||
std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
|
||||
std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
|
||||
std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
|
||||
std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
|
||||
std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
auto problem_size_0 = problem_sizes_0[i];
|
||||
auto problem_size_1 = problem_sizes_1[i];
|
||||
|
||||
host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
|
||||
host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
|
||||
host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
|
||||
host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
|
||||
host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
|
||||
host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
|
||||
host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
|
||||
host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
|
||||
host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
|
||||
host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
|
||||
host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
|
||||
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
host_tensor_D1.at(i).host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
host_tensor_ref_D0.at(i).host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
host_tensor_ref_D1.at(i).host_view());
|
||||
|
||||
host_tensor_A0.at(i).sync_device();
|
||||
host_tensor_B0.at(i).sync_device();
|
||||
host_tensor_C0.at(i).sync_device();
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
host_tensor_Scale0.at(i).sync_device();
|
||||
host_tensor_Bias0.at(i).sync_device();
|
||||
host_tensor_B1.at(i).sync_device();
|
||||
host_tensor_C1.at(i).sync_device();
|
||||
host_tensor_Bias1.at(i).sync_device();
|
||||
host_tensor_D1.at(i).sync_device();
|
||||
host_tensor_ref_D0.at(i).sync_device();
|
||||
host_tensor_ref_D1.at(i).sync_device();
|
||||
|
||||
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
|
||||
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());;
|
||||
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
|
||||
ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
|
||||
ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
|
||||
ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
|
||||
ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
|
||||
ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
|
||||
ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
|
||||
ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
|
||||
ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
|
||||
device_ref_A0.copy_from_host(ref_A0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
|
||||
device_ref_B0.copy_from_host(ref_B0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
|
||||
device_ref_C0.copy_from_host(ref_C0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
|
||||
device_ref_Scale0.copy_from_host(ref_Scale0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
|
||||
device_ref_Bias0.copy_from_host(ref_Bias0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
|
||||
device_ref_B1.copy_from_host(ref_B1.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
|
||||
device_ref_C1.copy_from_host(ref_C1.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
|
||||
device_ref_Bias1.copy_from_host(ref_Bias1.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
|
||||
device_ref_D1.copy_from_host(ref_D1.data());
|
||||
|
||||
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
|
||||
device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
|
||||
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
|
||||
device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
|
||||
if (!threadblock_count) {
|
||||
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
problem_count,
|
||||
device_problem_sizes_0.get(),
|
||||
device_problem_sizes_1.get(),
|
||||
device_ref_A0.get(),
|
||||
device_ref_B0.get(),
|
||||
device_ref_C0.get(),
|
||||
device_ref_Scale0.get(),
|
||||
device_ref_Bias0.get(),
|
||||
device_ref_B1.get(),
|
||||
device_ref_C1.get(),
|
||||
device_ref_D1.get(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
threadblock_count
|
||||
};
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.M = problem_size_1.M\n"
|
||||
<< " problem_size_0.N = problem_size_1.K\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
||||
}
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
host_tensor_D1.at(i).sync_host();;
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator>
|
||||
reference_gemm_1;
|
||||
|
||||
auto problem_size_0 = problem_sizes_0[i];
|
||||
auto problem_size_1 = problem_sizes_1[i];
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
ref_A0.at(i),
|
||||
ref_B0.at(i),
|
||||
ElementAccumulator(0), //beta = 0
|
||||
ref_Z.at(i),
|
||||
ref_Z.at(i),
|
||||
ElementAccumulator(0)
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutC
|
||||
> (
|
||||
problem_size_0,
|
||||
ref_Z.at(i),
|
||||
ref_ref_D0.at(i),
|
||||
alpha0,
|
||||
ref_Scale0.at(i),
|
||||
ref_Bias0.at(i)
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
ref_ref_D0.at(i),
|
||||
ref_B1.at(i),
|
||||
beta1,
|
||||
{host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
ref_ref_D1.at(i)
|
||||
);
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
|
||||
}
|
||||
cudaDeviceSynchronize();
|
||||
host_tensor_ref_D0.at(i).sync_host();
|
||||
host_tensor_ref_D1.at(i).sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
host_tensor_ref_D1.at(i).host_view(),
|
||||
host_tensor_D1.at(i).host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
if (!passed)
|
||||
{
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bGemm_device_fused.txt";
|
||||
std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "GEMM " << i << " in group\n"
|
||||
<< "A0 =\n" << host_tensor_A0.at(i).host_view()
|
||||
<< "\nB0 =\n" << host_tensor_B0.at(i).host_view()
|
||||
<< "\nC0 =\n" << host_tensor_C0.at(i).host_view()
|
||||
<< "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
|
||||
<< "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
|
||||
<< "\nB1 =\n" << host_tensor_B1.at(i).host_view()
|
||||
<< "\nC1 =\n" << host_tensor_C1.at(i).host_view()
|
||||
<< "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
|
||||
<< "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
|
||||
<< "\nComputed =\n" << host_tensor_D1.at(i).host_view();
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -43,6 +43,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
//
|
||||
|
||||
B2bInterleavedNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
//
|
||||
|
||||
B2bInterleavedFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
// batch_count is used as split-k when mode is kGemm according
|
||||
// to the GemmUniversal interface
|
||||
|
||||
int batch_count = 1,
|
||||
|
||||
int64_t batch_stride_A0 = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_C0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C1 = 0,
|
||||
int64_t batch_stride_D1 = 0,
|
||||
int64_t batch_stride_Bias0 = 0,
|
||||
int64_t batch_stride_Scale0 = 0,
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
||||
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
||||
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.n()});
|
||||
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
//Reorder B0
|
||||
cutlass::reorder_column<16>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
// tensor_Bias0_batched.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
mode,
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_B1_reordered.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
tensor_D1.device_ref(),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0,
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
batch_count,
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_A0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementAccumulator(0), //beta = 0
|
||||
reference_Z0.device_ref(),
|
||||
reference_Z0.device_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_C0
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutScaleBias
|
||||
> (
|
||||
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
|
||||
reference_D0.device_ref(),
|
||||
alpha0,
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref()
|
||||
tensor_Bias0.device_ref(),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_C0,
|
||||
batch_stride_Scale0,
|
||||
batch_stride_Bias0
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
alpha1, //intermediate alpha=1
|
||||
reference_D0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B1.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
beta1, //beta = 0
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
reference_D1.device_ref(),
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
|
||||
@ -119,8 +119,6 @@ template <
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// If true, kernel supports split-K with serial reduction
|
||||
bool SplitKSerial = false,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
@ -154,7 +152,6 @@ class B2bGemm {
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
@ -184,77 +181,11 @@ class B2bGemm {
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator,
|
||||
SmemAccumulator
|
||||
>::B2bGemmKernel;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord problem_size_0;
|
||||
GemmCoord problem_size_1;
|
||||
TensorRef<ElementA const, LayoutA> ref_A0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B0;
|
||||
TensorRef<ElementC const, LayoutC> ref_C0;
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B1;
|
||||
TensorRef<ElementC const, LayoutC> ref_C1;
|
||||
TensorRef<ElementC, LayoutC> ref_D1;
|
||||
typename EpilogueOutputOp0::Params epilogue0;
|
||||
typename EpilogueOutputOp1::Params epilogue1;
|
||||
int split_k_slices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmCoord problem_size_0_,
|
||||
GemmCoord problem_size_1_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B0_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C0_,
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B1_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C1_,
|
||||
TensorRef<ElementC, LayoutC> ref_D1_,
|
||||
typename EpilogueOutputOp0::Params epilogue0_ =
|
||||
typename EpilogueOutputOp0::Params(),
|
||||
typename EpilogueOutputOp1::Params epilogue1_ =
|
||||
typename EpilogueOutputOp1::Params(),
|
||||
int split_k_slices_ = 1
|
||||
):
|
||||
problem_size_0(problem_size_0_),
|
||||
problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_),
|
||||
ref_B0(ref_B0_),
|
||||
ref_C0(ref_C0_),
|
||||
ref_Scale0(ref_Scale0_),
|
||||
ref_Bias0(ref_Bias0_),
|
||||
ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_),
|
||||
ref_D1(ref_D1_),
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
split_k_slices(split_k_slices_) {
|
||||
|
||||
}
|
||||
};
|
||||
using Arguments = typename B2bGemmKernel::Arguments;
|
||||
|
||||
private:
|
||||
|
||||
@ -269,10 +200,6 @@ public:
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
Status status = B2bGemmKernel::can_implement(
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
@ -295,20 +222,14 @@ public:
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
|
||||
|
||||
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||||
}
|
||||
args.batch_count);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
@ -320,38 +241,17 @@ public:
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
args.batch_count);
|
||||
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
|
||||
// args.problem_size_1,
|
||||
// args.problem_size_1,
|
||||
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
|
||||
// args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial) {
|
||||
if (args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
size_t bytes = get_workspace_size(args);
|
||||
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
if (args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
// args.batch_count);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename B2bGemmKernel::Params{
|
||||
args.mode,
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
grid_shape,
|
||||
@ -363,6 +263,13 @@ public:
|
||||
args.ref_B1.non_const_ref(),
|
||||
args.ref_C1.non_const_ref(),
|
||||
args.ref_D1,
|
||||
args.batch_stride_A0,
|
||||
args.batch_stride_B0,
|
||||
args.batch_stride_B1,
|
||||
args.batch_stride_C1,
|
||||
args.batch_stride_D1,
|
||||
args.batch_stride_Bias0,
|
||||
args.batch_stride_Scale0,
|
||||
args.epilogue0,
|
||||
args.epilogue1,
|
||||
static_cast<int *>(workspace),
|
||||
@ -373,12 +280,6 @@ public:
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
}
|
||||
|
||||
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
|
||||
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
|
||||
@ -430,12 +331,12 @@ public:
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
@ -220,7 +220,6 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
@ -229,10 +228,6 @@ int main() {
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 RF residency");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -39,7 +39,6 @@
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
@ -219,20 +218,13 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm75_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -0,0 +1,297 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/base_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm.h"
|
||||
#include "threadblock/grouped_threadblock_swizzle.h"
|
||||
#include "b2b_grouped_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;
|
||||
|
||||
// Constraints:
|
||||
// 1. Warp shape N must equal thread block shape N
|
||||
// 2. Problem size N must equal thread block shape N
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
bool reference_check;
|
||||
int alignment = 8;
|
||||
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
|
||||
|
||||
int problem_count;
|
||||
bool verbose;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
reference_check(true),
|
||||
problem_count(15),
|
||||
verbose(false)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("problems", problem_count, 15);
|
||||
cmd.get_cmd_line_argument("reference-check", reference_check, true);
|
||||
cmd.get_cmd_line_argument("verbose", verbose, false);
|
||||
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
|
||||
//
|
||||
// For now, randomly choose the problem sizes.
|
||||
//
|
||||
|
||||
int cmd_line_m = -1;
|
||||
int cmd_line_k = -1;
|
||||
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes0.reserve(problem_count);
|
||||
problem_sizes1.reserve(problem_count);
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
|
||||
int m = cmd_line_m;
|
||||
int k = cmd_line_k;
|
||||
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 256) + 1);
|
||||
}
|
||||
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 256) + 1);
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
|
||||
cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
|
||||
|
||||
problem_sizes0.push_back(problem0);
|
||||
problem_sizes1.push_back(problem1);
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
print_problem_sizes();
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
|
||||
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
|
||||
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
|
||||
<< " back-to-back GEMMs are subject to.s"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --problems=<int> Number of individual GEMM problems (default: --problems=15)\n"
|
||||
<< " --m=<int> Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
|
||||
<< " --k=<int> Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
|
||||
<< " --verbose=<bool> If true, prints problem sizes.\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
|
||||
<< "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
|
||||
<< "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --groups=10\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void print_problem_sizes() {
|
||||
std::cout << std::endl;
|
||||
std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
|
||||
cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
|
||||
std::cout << "Problem " << i
|
||||
<< "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
|
||||
<< "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool run_fused_grouped_gemm_f16_sm80_rf_res() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
|
||||
ThreadblockShape0,
|
||||
cutlass::layout::RowMajor // LayoutC
|
||||
>;
|
||||
|
||||
const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
const int kStages = 3;
|
||||
using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
kAlignment,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
kAlignment,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
GroupedThreadblockSwizzle,
|
||||
kStages,
|
||||
cutlass::arch::OpMultiplyAdd
|
||||
>::B2bGemmKernel;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;
|
||||
|
||||
B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
|
||||
gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_fused_grouped_gemm_f16_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "grouped gemm f16 RF residency");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -195,7 +195,6 @@ bool run_fused_gemm_s8_rf_res() {
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
@ -204,9 +203,6 @@ int main() {
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -43,7 +43,6 @@
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
|
||||
@ -197,18 +196,13 @@ bool run_fused_gemm_s8_shmem() {
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8,
|
||||
&run_fused_gemm_s8_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm int8 shmem staing");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
bool passed = fusedGemm.run(
|
||||
gemm_s8_sm80_problem_size_0,
|
||||
gemm_s8_sm80_problem_size_1,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_sm80_rf_res_batch() {
|
||||
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
int batch_count = 2;
|
||||
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
|
||||
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
|
||||
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
|
||||
int64_t batch_stride_Scale0 = 0;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(
|
||||
gemm_s8_sm80_problem_size_0,
|
||||
gemm_s8_sm80_problem_size_1,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
cutlass::gemm::GemmUniversalMode::kBatched,
|
||||
batch_count,
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8_sm80,
|
||||
&run_fused_gemm_s8_sm80_rf_res
|
||||
&run_fused_gemm_s8_sm80_rf_res,
|
||||
&run_fused_gemm_s8_sm80_rf_res_batch
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
|
||||
@ -40,19 +40,66 @@
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
|
||||
#include "threadblock/grouped_threadblock_swizzle.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Utility struct for returning the type of the problem visitor used by the swizzling function,
|
||||
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
|
||||
/// the parameters of the problem visitor used in GroupedParams.
|
||||
template <
|
||||
typename B2bMma_,
|
||||
typename ThreadblockSwizzle_,
|
||||
typename Enable = void
|
||||
>
|
||||
struct ProblemVisitorOrDefault;
|
||||
|
||||
/// Return a generic problem visitor for GEMM problems
|
||||
template <
|
||||
typename B2bMma_,
|
||||
typename ThreadblockSwizzle_
|
||||
>
|
||||
struct ProblemVisitorOrDefault<B2bMma_,
|
||||
ThreadblockSwizzle_,
|
||||
typename platform::enable_if<
|
||||
! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
|
||||
>::type> {
|
||||
using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
|
||||
GroupScheduleMode::kDeviceOnly,
|
||||
128,
|
||||
128,
|
||||
platform::is_same<typename B2bMma_::LayoutC,
|
||||
cutlass::layout::ColumnMajor>::value>;
|
||||
};
|
||||
|
||||
/// Return the problem visitor specified by the swizzling function
|
||||
template <
|
||||
typename B2bMma_,
|
||||
typename ThreadblockSwizzle_
|
||||
>
|
||||
struct ProblemVisitorOrDefault<B2bMma_,
|
||||
ThreadblockSwizzle_,
|
||||
typename platform::enable_if<
|
||||
cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
|
||||
>::type> {
|
||||
using value = typename ThreadblockSwizzle_::ProblemVisitor;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct B2bGemm {
|
||||
|
||||
@ -61,14 +108,184 @@ struct B2bGemm {
|
||||
using OutputOp0 = typename B2bMma::OutputOp;
|
||||
using OutputOp1 = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using ElementA0 = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB0 = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
|
||||
using ElementB1 = typename B2bMma::IteratorB1::Element;
|
||||
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||
|
||||
/// Data types needed for higher-level containers. In some cases, a single type must be exposed
|
||||
/// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
|
||||
/// the second GEMM (other than for ElementA/ElementB)
|
||||
using ElementA = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB = typename B2bMma::IteratorB0::Layout;
|
||||
|
||||
static ComplexTransform const kTransformA = B2bMma::kTransformA;
|
||||
static ComplexTransform const kTransformB = B2bMma::kTransformB;
|
||||
using Operator = typename B2bMma::Operator0;
|
||||
|
||||
using OperatorClass = typename Operator::OperatorClass;
|
||||
using ThreadblockShape = typename B2bMma::Shape0;
|
||||
using WarpShape = typename Operator::Shape;
|
||||
using InstructionShape = typename Operator::InstructionShape;
|
||||
using ArchTag = typename B2bMma::ArchTag;
|
||||
|
||||
static int const kStages = B2bMma::kStages;
|
||||
static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using Mma = B2bMma;
|
||||
using EpilogueOutputOp = OutputOp1;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||
static int const kThreadCount = 32 * WarpCount0::kCount;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size_0;
|
||||
GemmCoord problem_size_1;
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0;
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
typename OutputOp0::Params epilogue0;
|
||||
typename OutputOp1::Params epilogue1;
|
||||
int batch_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmUniversalMode mode_,
|
||||
GemmCoord problem_size_0_,
|
||||
GemmCoord problem_size_1_,
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0_,
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
|
||||
int64_t batch_stride_A0_,
|
||||
int64_t batch_stride_B0_,
|
||||
int64_t batch_stride_B1_,
|
||||
int64_t batch_stride_C1_,
|
||||
int64_t batch_stride_D1_,
|
||||
int64_t batch_stride_Bias0_,
|
||||
int64_t batch_stride_Scale0_,
|
||||
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
|
||||
int batch_count_ = 1
|
||||
):
|
||||
mode(mode_),
|
||||
problem_size_0(problem_size_0_),
|
||||
problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_),
|
||||
ref_B0(ref_B0_),
|
||||
ref_C0(ref_C0_),
|
||||
ref_Scale0(ref_Scale0_),
|
||||
ref_Bias0(ref_Bias0_),
|
||||
ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_),
|
||||
ref_D1(ref_D1_),
|
||||
batch_stride_A0(batch_stride_A0_),
|
||||
batch_stride_B0(batch_stride_B0_),
|
||||
batch_stride_B1(batch_stride_B1_),
|
||||
batch_stride_C1(batch_stride_C1_),
|
||||
batch_stride_D1(batch_stride_D1_),
|
||||
batch_stride_Bias0(batch_stride_Bias0_),
|
||||
batch_stride_Scale0(batch_stride_Scale0_),
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
batch_count(batch_count_) {
|
||||
}
|
||||
};
|
||||
|
||||
// Arguments structure for grouped B2B problems
|
||||
struct GroupedArguments {
|
||||
GemmCoord* problem_size_0;
|
||||
GemmCoord* problem_size_1;
|
||||
typename B2bMma::IteratorA0::TensorRef* ref_A0;
|
||||
typename B2bMma::IteratorB0::TensorRef* ref_B0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
|
||||
typename B2bMma::IteratorB1::TensorRef* ref_B1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params epilogue0;
|
||||
typename OutputOp1::Params epilogue1;
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedArguments(
|
||||
int problem_count,
|
||||
GemmCoord* problem_size_0_,
|
||||
GemmCoord* problem_size_1_,
|
||||
typename B2bMma::IteratorA0::TensorRef* ref_A0_,
|
||||
typename B2bMma::IteratorB0::TensorRef* ref_B0_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
|
||||
typename B2bMma::IteratorB1::TensorRef* ref_B1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
|
||||
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
|
||||
int threadblock_count = 0
|
||||
) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
|
||||
ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
|
||||
problem_count(problem_count),
|
||||
threadblock_count(threadblock_count)
|
||||
{}
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmUniversalMode mode;
|
||||
cutlass::gemm::GemmCoord problem_size_0;
|
||||
cutlass::gemm::GemmCoord problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
@ -89,6 +306,13 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_size_0;
|
||||
@ -100,11 +324,12 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmUniversalMode mode,
|
||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
@ -116,14 +341,22 @@ struct B2bGemm {
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
||||
int64_t batch_stride_A0,
|
||||
int64_t batch_stride_B0,
|
||||
int64_t batch_stride_B1,
|
||||
int64_t batch_stride_C1,
|
||||
int64_t batch_stride_D1,
|
||||
int64_t batch_stride_Bias0,
|
||||
int64_t batch_stride_Scale0,
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
|
||||
params_A0(ref_A0.layout()),
|
||||
ref_A0(ref_A0),
|
||||
params_B0(ref_B0.layout()),
|
||||
@ -138,6 +371,13 @@ struct B2bGemm {
|
||||
ref_C1(ref_C1),
|
||||
params_D1(ref_D1.layout()),
|
||||
ref_D1(ref_D1),
|
||||
batch_stride_A0(batch_stride_A0),
|
||||
batch_stride_B0(batch_stride_B0),
|
||||
batch_stride_B1(batch_stride_B1),
|
||||
batch_stride_C1(batch_stride_C1),
|
||||
batch_stride_D1(batch_stride_D1),
|
||||
batch_stride_Bias0(batch_stride_Bias0),
|
||||
batch_stride_Scale0(batch_stride_Scale0),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1) {
|
||||
|
||||
@ -152,6 +392,81 @@ struct B2bGemm {
|
||||
}
|
||||
};
|
||||
|
||||
struct GroupedParams {
|
||||
cutlass::gemm::GemmCoord* problem_size_0;
|
||||
cutlass::gemm::GemmCoord* problem_size_1;
|
||||
cutlass::gemm::GemmCoord* grid_tiled_shape;
|
||||
typename B2bMma::IteratorA0::TensorRef* ref_A0;
|
||||
typename B2bMma::IteratorB0::TensorRef* ref_B0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
|
||||
typename B2bMma::IteratorB1::TensorRef* ref_B1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
|
||||
using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
int* workspace;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedParams() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedParams(
|
||||
GroupedArguments const &args,
|
||||
void *workspace = nullptr,
|
||||
int tile_count = 0
|
||||
) :
|
||||
problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
|
||||
ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
|
||||
ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
|
||||
output_op_0(args.epilogue0), output_op_1(args.epilogue1),
|
||||
problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
|
||||
threadblock_count(args.threadblock_count),
|
||||
workspace(reinterpret_cast<int*>(workspace)) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void transpose() {
|
||||
// Only row-major outputs are currently supported, so no transpose is performed
|
||||
}
|
||||
|
||||
/// Returns non-grouped paramaters to be used as input to the kernel-level
|
||||
/// operator for the problem indicated by problem_visitor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params to_single_params(const ProblemVisitor& problem_visitor) const {
|
||||
GemmCoord problem_size0 = problem_visitor.problem_size0();
|
||||
GemmCoord problem_size1 = problem_visitor.problem_size1();
|
||||
int32_t idx = problem_visitor.problem_index();
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
|
||||
|
||||
return Params(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size0,
|
||||
problem_size1,
|
||||
grid_shape,
|
||||
ref_A0[idx],
|
||||
ref_B0[idx],
|
||||
ref_C0[idx],
|
||||
ref_Scale0[idx],
|
||||
ref_Bias0[idx],
|
||||
ref_B1[idx],
|
||||
ref_C1[idx],
|
||||
ref_D1[idx],
|
||||
0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
|
||||
output_op_0,
|
||||
output_op_1,
|
||||
workspace
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename B2bMma::B2bMmaSharedStorage main_loop;
|
||||
@ -163,7 +478,7 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bGemm() { }
|
||||
B2bGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
@ -223,7 +538,7 @@ struct B2bGemm {
|
||||
|
||||
if(problem_size_0.n() > B2bMma::Shape0::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
|
||||
if(problem_size_1.n() > B2bMma::Shape1::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
@ -233,9 +548,13 @@ struct B2bGemm {
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
run_with_swizzle(params, shared_storage, threadblock_swizzle);
|
||||
}
|
||||
|
||||
/// Executes one GEMM with an externally-provided swizzling function
|
||||
CUTLASS_DEVICE
|
||||
void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
@ -247,37 +566,64 @@ struct B2bGemm {
|
||||
return;
|
||||
}
|
||||
|
||||
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
|
||||
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
|
||||
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
|
||||
|
||||
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
|
||||
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
|
||||
|
||||
int offset_k_0 = 0;
|
||||
int offset_k_1 = 0;
|
||||
|
||||
int problem_size_k_0 = params.problem_size_0.k();
|
||||
int problem_size_k_1 = params.problem_size_1.k();
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_0 = min(
|
||||
problem_size_k_0,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_1 = min(
|
||||
problem_size_k_1,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
|
||||
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
|
||||
}
|
||||
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
|
||||
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
||||
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
||||
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
|
||||
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_1,
|
||||
offset_k_1,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_0 = min(
|
||||
params.problem_size_0.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_1 = min(
|
||||
params.problem_size_1.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
|
||||
|
||||
// Compute position within threadblock
|
||||
@ -286,34 +632,33 @@ struct B2bGemm {
|
||||
// Construct iterators to A and B operands
|
||||
typename B2bMma::IteratorA0 iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
ptr_A0,
|
||||
{params.problem_size_0.m(), problem_size_k_0},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename B2bMma::IteratorB0 iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
ptr_B0,
|
||||
{problem_size_k_0, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename B2bMma::IteratorB1 iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
ptr_B1,
|
||||
{problem_size_k_1, params.problem_size_1.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||
params.ref_Scale0.data(),
|
||||
ptr_Scale0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -323,7 +668,7 @@ struct B2bGemm {
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||
params.ref_Bias0.data(),
|
||||
ptr_Bias0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -332,14 +677,17 @@ struct B2bGemm {
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
OutputOp0 output_op_0(params.output_op_0);
|
||||
|
||||
if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
|
||||
|
||||
@ -349,11 +697,9 @@ struct B2bGemm {
|
||||
src_accum.clear();
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
}
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
@ -376,23 +722,32 @@ struct B2bGemm {
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
if (params.grid_tiled_shape.k() > 1) {
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
|
||||
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
ptr_C1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@ -401,21 +756,21 @@ struct B2bGemm {
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
ptr_D1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C1 = iterator_D1;
|
||||
@ -427,14 +782,14 @@ struct B2bGemm {
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
@ -457,4 +812,3 @@ struct B2bGemm {
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -0,0 +1,157 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Scheduler for grouped B2b GEMMs
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ThreadblockShape,
|
||||
GroupScheduleMode GroupScheduleMode_,
|
||||
int PrefetchTileCount,
|
||||
int ThreadCount,
|
||||
bool Transposed = false>
|
||||
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
|
||||
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
|
||||
ThreadblockShape,
|
||||
GroupScheduleMode_,
|
||||
PrefetchTileCount,
|
||||
ThreadCount> {
|
||||
|
||||
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
|
||||
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
|
||||
using BaseParams = typename Base::Params;
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
cutlass::gemm::GemmCoord const *problem_sizes0;
|
||||
cutlass::gemm::GemmCoord const *problem_sizes1;
|
||||
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord const *problem_sizes0;
|
||||
cutlass::gemm::GemmCoord const *problem_sizes1;
|
||||
int32_t problem_count;
|
||||
void const *workspace;
|
||||
int32_t tile_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
|
||||
problem_count(0), workspace(nullptr), tile_count(0) { }
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const *problem_sizes0,
|
||||
cutlass::gemm::GemmCoord const *problem_sizes1,
|
||||
int32_t problem_count,
|
||||
void const *workspace = nullptr,
|
||||
int32_t tile_count = 0
|
||||
):
|
||||
problem_sizes0(problem_sizes0),
|
||||
problem_sizes1(problem_sizes1),
|
||||
problem_count(problem_count),
|
||||
workspace(workspace),
|
||||
tile_count(tile_count)
|
||||
{}
|
||||
|
||||
/// Convert the B2b-GEMM-specific parameters to those used by the base class
|
||||
CUTLASS_HOST_DEVICE
|
||||
BaseParams to_base() const {
|
||||
return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
|
||||
// shape of the grid used in the non-grouped B2b GEMM
|
||||
problem_sizes0,
|
||||
problem_count,
|
||||
workspace,
|
||||
tile_count);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
B2bGemmGroupedProblemVisitor(
|
||||
Params const ¶ms_,
|
||||
SharedStorage &shared_storage_,
|
||||
int32_t block_idx
|
||||
): Base (
|
||||
params_.to_base(),
|
||||
shared_storage_, block_idx),
|
||||
problem_sizes0(params_.problem_sizes0),
|
||||
problem_sizes1(params_.problem_sizes1)
|
||||
{}
|
||||
|
||||
/// Returns the problem size 0 for the current problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size0() const {
|
||||
GemmCoord problem = problem_sizes0[this->problem_idx];
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
return problem;
|
||||
}
|
||||
|
||||
/// Returns the problem size 1 for the current problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size1() const {
|
||||
GemmCoord problem = problem_sizes1[this->problem_idx];
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
return problem;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -63,7 +63,9 @@
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "kernel/grouped.h"
|
||||
#include "threadblock/default_b2b_mma.h"
|
||||
#include "threadblock/grouped_threadblock_swizzle.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -73,6 +75,9 @@ namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
@ -114,12 +119,12 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Stage accumulator in shared memory
|
||||
bool SmemAccumulator = false
|
||||
bool SmemAccumulator = false,
|
||||
/// Whether or not the operation is grouped
|
||||
typename Enable = void
|
||||
>
|
||||
struct DefaultB2bGemm;
|
||||
|
||||
@ -161,17 +166,77 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
Operator> {
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
/// Partial specialization for Ampere Architecture with grouped operation
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
@ -188,7 +253,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
|
||||
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
|
||||
};
|
||||
|
||||
|
||||
@ -228,8 +295,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -249,8 +314,9 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator
|
||||
Operator,
|
||||
false,
|
||||
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
|
||||
> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
@ -274,7 +340,7 @@ struct DefaultB2bGemm<
|
||||
Operator,
|
||||
EpilogueOutputOp0
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -287,7 +353,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -323,20 +389,17 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator> {
|
||||
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -360,7 +423,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -396,19 +459,17 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
|
||||
ThreadblockSwizzle, 2, Operator, false,
|
||||
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -418,7 +479,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -430,7 +491,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -112,22 +112,19 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, true> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
@ -179,8 +175,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -277,20 +270,17 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, true> {
|
||||
Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -350,19 +340,16 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
|
||||
ThreadblockSwizzle, 2, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
168
examples/13_two_tensor_op_fusion/kernel/grouped.h
Normal file
168
examples/13_two_tensor_op_fusion/kernel/grouped.h
Normal file
@ -0,0 +1,168 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief High-level interface for running a grouped version of a CUTLASS kernel
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// High-level interface for running a grouped version of a CUTLASS kernel
|
||||
template <
|
||||
typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate
|
||||
>
|
||||
struct GroupedKernel {
|
||||
public:
|
||||
|
||||
using BaseKernel = BaseKernel_;
|
||||
using Epilogue = typename BaseKernel::Epilogue;
|
||||
|
||||
/// Types that need to be exported to work properly with device::BaseGrouped
|
||||
using ElementA = typename BaseKernel::ElementA;
|
||||
using LayoutA = typename BaseKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
|
||||
static int const kAlignmentA = BaseKernel::kAlignmentA;
|
||||
|
||||
using ElementB = typename BaseKernel::ElementB;
|
||||
using LayoutB = typename BaseKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
|
||||
static int const kAlignmentB = BaseKernel::kAlignmentB;
|
||||
|
||||
using ElementC = typename BaseKernel::ElementC;
|
||||
using LayoutC = typename BaseKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
static int const kAlignmentC = BaseKernel::kAlignmentC;
|
||||
|
||||
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
|
||||
|
||||
using Operator = typename BaseKernel::Operator;
|
||||
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename WarpMmaOperator::MathOperator;
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
using ThreadblockShape = typename BaseKernel::Mma::Shape;
|
||||
using WarpShape = typename BaseKernel::WarpShape;
|
||||
using InstructionShape = typename BaseKernel::InstructionShape;
|
||||
static int const kStages = BaseKernel::Mma::kStages;
|
||||
|
||||
using Mma = typename BaseKernel::Mma;
|
||||
|
||||
using Arguments = typename BaseKernel::GroupedArguments;
|
||||
using Params = typename BaseKernel::GroupedParams;
|
||||
using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;
|
||||
|
||||
static int const kThreadCount = BaseKernel::kThreadCount;
|
||||
|
||||
/// Shared memory storage structure
|
||||
struct SharedStorage {
|
||||
typename BaseKernel::SharedStorage kernel;
|
||||
|
||||
// ProblemVisitor shared storage can't be overlapped with others
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GroupedKernel() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const &args) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Executes a kernel-level GEMM in a loop
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
if (ProblemVisitor::kTransposed) {
|
||||
params.transpose();
|
||||
}
|
||||
|
||||
BaseKernel mma;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (swizzle.problem_visitor.next_tile()) {
|
||||
|
||||
typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
|
||||
mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);
|
||||
|
||||
// Next tile
|
||||
swizzle.problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias ///< bias tensor
|
||||
) {
|
||||
|
||||
|
||||
ConvertOp convert_op;
|
||||
|
||||
MatrixCoord output_coord(
|
||||
@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
|
||||
if(tensor_bias.good())
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, coord.column()});
|
||||
|
||||
tensor_out.at(coord) = convert_op(
|
||||
@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename TensorRefIn, ///< Input TensorRef Type
|
||||
typename TensorRefOut, ///< Output TensorRef Type
|
||||
typename ScalarType, ///< alpha Type
|
||||
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
|
||||
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
|
||||
int kMblock = 4,
|
||||
int kNblock = 4
|
||||
>
|
||||
__global__ void TensorScaleBiasGemmBatched(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRefIn tensor_in, ///< input tensor
|
||||
TensorRefOut tensor_out, ///< output tensor
|
||||
ScalarType alpha, ///< alpha
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias, ///< bias tensor
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_tensor_in = 0,
|
||||
int64_t batch_stride_tensor_out = 0,
|
||||
int64_t batch_stride_tensor_scale = 0,
|
||||
int64_t batch_stride_tensor_bias = 0
|
||||
) {
|
||||
|
||||
ConvertOp convert_op;
|
||||
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
||||
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
||||
int batch_idx = blockIdx.z;
|
||||
|
||||
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
|
||||
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
|
||||
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
|
||||
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
|
||||
|
||||
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
|
||||
|
||||
ScalarType scale = alpha;
|
||||
if(tensor_scale.good())
|
||||
scale = tensor_scale.at({0, coord.column()});
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, coord.column()});
|
||||
|
||||
tensor_out.at(coord) = convert_op(
|
||||
scale * ScalarType(tensor_in.at(coord)) + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
|
||||
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
|
||||
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
|
||||
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename TensorRefIn, ///< Input TensorRef Type
|
||||
typename TensorRefOut, ///< Output TensorRef Type
|
||||
@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias ///< bias tensor
|
||||
) {
|
||||
|
||||
|
||||
ConvertOp convert_op;
|
||||
|
||||
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
||||
@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
|
||||
int64_t npq = npq_start + m;
|
||||
|
||||
thread_n[m] = int(npq / PQ);
|
||||
|
||||
|
||||
int64_t residual = npq % PQ;
|
||||
thread_p[m] = int(residual / problem_size.Q);
|
||||
thread_q[m] = int(residual % problem_size.Q);
|
||||
@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
|
||||
ScalarType scale = alpha;
|
||||
if(tensor_scale.good())
|
||||
scale = tensor_scale.at({0, thread_k});
|
||||
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
if(tensor_bias.good())
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, thread_k});
|
||||
|
||||
|
||||
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
|
||||
scale * ScalarType(
|
||||
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
|
||||
) + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
|
||||
);
|
||||
}
|
||||
|
||||
/// Apply scale and bias on a tensor
|
||||
template <
|
||||
typename ElementIn, ///< Input Type
|
||||
typename ElementOut, ///< Output Type
|
||||
typename Layout, ///< Layout of input/output tensor
|
||||
typename ScalarType, ///< alpha Type
|
||||
typename LayoutScaleBias, ///< Layout of scale and bias
|
||||
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
|
||||
>
|
||||
void TensorScaleBiasGemmBatched(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
|
||||
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
|
||||
ScalarType alpha, ///< alpha
|
||||
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
|
||||
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_tensor_in = 0,
|
||||
int64_t batch_stride_tensor_out = 0,
|
||||
int64_t batch_stride_tensor_scale = 0,
|
||||
int64_t batch_stride_tensor_bias = 0
|
||||
) {
|
||||
|
||||
int const kMblock = 4;
|
||||
int const kNblock = 4;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
||||
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
||||
batch_count % std::numeric_limits<uint16_t>::max()
|
||||
);
|
||||
|
||||
kernel::TensorScaleBiasGemmBatched<
|
||||
TensorRef<ElementIn, Layout>,
|
||||
TensorRef<ElementOut, Layout>,
|
||||
ScalarType,
|
||||
TensorRef<ScalarType, LayoutScaleBias>,
|
||||
ConvertOp,
|
||||
kMblock,
|
||||
kNblock
|
||||
><<< grid, block >>> (
|
||||
problem_size,
|
||||
tensor_in,
|
||||
tensor_out,
|
||||
alpha,
|
||||
tensor_scale,
|
||||
tensor_bias,
|
||||
batch_count,
|
||||
batch_stride_tensor_in,
|
||||
batch_stride_tensor_out,
|
||||
batch_stride_tensor_scale,
|
||||
batch_stride_tensor_bias
|
||||
);
|
||||
}
|
||||
|
||||
/// Apply scale and bias on a tensor
|
||||
template <
|
||||
typename ElementIn, ///< Input Type
|
||||
|
||||
@ -119,8 +119,10 @@ public:
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
using IteratorA = IteratorA0;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
using IteratorB = IteratorB0;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
@ -139,6 +141,10 @@ public:
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
///< Export Policy0 as the threadblock-level Mma's policy
|
||||
using Policy = Policy0;
|
||||
using Shape = Shape0;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
@ -188,6 +194,10 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Complex transform exports needed by higher-level kernels
|
||||
static ComplexTransform const kTransformA = kTransformA0;
|
||||
static ComplexTransform const kTransformB = kTransformB0;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
@ -641,6 +651,11 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
@ -871,7 +886,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -121,8 +121,10 @@ public:
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
using IteratorA = IteratorA0;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
using IteratorB = IteratorB0;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
@ -141,6 +143,10 @@ public:
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
///< Export Policy0 as the threadblock-level Mma's policy
|
||||
using Policy = Policy0;
|
||||
using Shape = Shape0;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
@ -194,6 +200,10 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Complex transform exports needed by higher-level kernels
|
||||
static ComplexTransform const kTransformA = kTransformA0;
|
||||
static ComplexTransform const kTransformB = kTransformB0;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
@ -664,6 +674,11 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
@ -855,7 +870,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -126,7 +126,9 @@ public:
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA0;
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB0;
|
||||
using Policy0 = Policy0_; ///< Policy describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
@ -139,6 +141,8 @@ public:
|
||||
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy describing tuning details
|
||||
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
|
||||
using Shape = Shape1;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
@ -195,6 +199,10 @@ public:
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Complex transform exports needed by higher-level kernels
|
||||
static ComplexTransform const kTransformA = kTransformA0;
|
||||
static ComplexTransform const kTransformB = kTransformB0;
|
||||
|
||||
/// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
|
||||
@ -128,7 +128,9 @@ public:
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA0;
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB0;
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using Policy0 = Policy0_; ///< Policy0 describing tuning details
|
||||
|
||||
@ -141,6 +143,8 @@ public:
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
|
||||
using Shape = Shape1;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
@ -192,6 +196,10 @@ public:
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Complex transform exports needed by higher-level kernels
|
||||
static ComplexTransform const kTransformA = kTransformA0;
|
||||
static ComplexTransform const kTransformB = kTransformB0;
|
||||
|
||||
/// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
|
||||
@ -0,0 +1,125 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements several threadblock-swizzling functions for grouped kernels
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct GroupedThreadblockSwizzleBase {};
|
||||
|
||||
/// Helper for determining if a swizzling function is specialized for grouped operation
|
||||
template <typename ThreadblockSwizzle>
|
||||
struct IsGroupedSwizzle {
|
||||
static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, ThreadblockSwizzle>::value;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Swizzling function for grouped kernels
|
||||
template <typename ProblemVisitor_>
|
||||
struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase {
|
||||
|
||||
using ProblemVisitor = ProblemVisitor_;
|
||||
ProblemVisitor problem_visitor;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params,
|
||||
typename ProblemVisitor::SharedStorage& shared_storage,
|
||||
int block_idx) : problem_visitor(params, shared_storage, block_idx) {}
|
||||
|
||||
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
|
||||
CUTLASS_DEVICE
|
||||
GemmCoord get_tile_offset(int /*log_tile*/) const {
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
return GemmCoord(int(threadblock_idx / grid_shape.n()),
|
||||
int(threadblock_idx % grid_shape.n()),
|
||||
0);
|
||||
}
|
||||
|
||||
/// Dummy method to satisfy API for threadblock swizzling functions
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int get_log_tile(GemmCoord /*tiled_shape*/) {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename ThreadblockShape,
|
||||
typename LayoutC,
|
||||
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
|
||||
int PrefetchTileCount = 128,
|
||||
int ThreadCount = PrefetchTileCount>
|
||||
struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
|
||||
cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
|
||||
ThreadblockShape,
|
||||
GroupScheduleMode_,
|
||||
PrefetchTileCount,
|
||||
ThreadCount,
|
||||
platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
|
||||
>
|
||||
> {
|
||||
using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
|
||||
ThreadblockShape,
|
||||
GroupScheduleMode_,
|
||||
PrefetchTileCount,
|
||||
ThreadCount,
|
||||
platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
|
||||
typename Base::ProblemVisitor::SharedStorage& shared_storage,
|
||||
int block_idx) : Base(params, shared_storage, block_idx) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -31,83 +31,181 @@
|
||||
|
||||
/**
|
||||
|
||||
This example shows how to run convolution kernels using functions and data structures
|
||||
provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU.
|
||||
This example shows how to run CUTLASS's convolution kernels
|
||||
based on the Implicit GEMM algorithm, that use the Tensor Cores
|
||||
on an NVIDIA Ampere GPU.
|
||||
|
||||
Writing a single high performance convolution kernel is hard but do-able. Whereas writing
|
||||
high performance kernels at scale which works for multiple problem sizes with good abstractions is
|
||||
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
|
||||
multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance
|
||||
of GPU easily.
|
||||
Writing a single high-performance convolution kernel is hard enough,
|
||||
let alone writing kernels that perform well for multiple problem sizes
|
||||
and use good software abstractions.
|
||||
CUTLASS provides simplified abstractions
|
||||
to compose multiple sections of a convolution kernel.
|
||||
When used properly, the kernels can reach peak GPU performance.
|
||||
|
||||
CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp
|
||||
and thread-block level, they compute on their own tile-size with higher level of tile sizes being
|
||||
composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used
|
||||
to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute
|
||||
threadblock-tile (tile size computed by a threadblock).
|
||||
CUTLASS divides a kernel into hierarchical composable sections
|
||||
for each level of the GPU hardware hierarchy:
|
||||
thread, warp, and threadblock.
|
||||
Each section computes on its own tile shape,
|
||||
with each higher level's tile shape
|
||||
being composed from lower-level tile shapes.
|
||||
Multiple thread tiles (the tile shape each thread computes)
|
||||
can be used to form warp tiles (the tile shape each warp computes),
|
||||
and multiple warp tiles can be used to compute threadblock tiles
|
||||
(the tile shape computed by a threadblock).
|
||||
|
||||
In thie example, we split variable initialization into
|
||||
1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel
|
||||
can view them (logical to physical mapping)
|
||||
2. Setting up computation properties : describes how the above set tensors will be used to compute
|
||||
output of convolution.
|
||||
In this example, we split variable initialization into two parts.
|
||||
|
||||
First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along
|
||||
with alpha, beta as the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C. In CUTLASS,
|
||||
the kernels first compute Conv2dFprop(A, B) and leave the rest of the computation to end of the kernel as
|
||||
alpha * X + beta * C is a simple element-wise operation on X (Conv2dFprop(A, B)) and C. We call this as
|
||||
epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to
|
||||
ElementComputeEpilogue = float. We use the data type for elements in input tensor A and B as
|
||||
cutlass::half_t. We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (float),
|
||||
ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t),
|
||||
ElementOutput (float). Communicating just the data type is not enough. As the data is laid out
|
||||
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
|
||||
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
|
||||
rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template
|
||||
variable EpilogueOp, which takes the data type of output ElementOutput (float), the number of
|
||||
elements per vector memory access (8), data type of accumulator (float) and data type of
|
||||
computation of linear combination (alpha * X + beta * C).
|
||||
1. Setting up data properties: describes how tensors are laid out in the memory
|
||||
and how the kernel can view them (logical to physical mapping)
|
||||
|
||||
Now that we setup the properties of data, we have to setup properties of computation.
|
||||
2. Setting up computation properties: describes how the above tensors
|
||||
will be used to compute the output of convolution
|
||||
|
||||
Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x64,
|
||||
64x64x64, 16x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it
|
||||
internally deduces the amount of threads needed per thread-block, amount of shared memory, storing
|
||||
data in bank-conflict free manner, and ton of other variables required to compose, initialize and
|
||||
launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer
|
||||
from understanding and coding complicated hardware optimizations which can easily go wrong.
|
||||
We begin by setting up the data types
|
||||
of all the input and output elements of a convolution.
|
||||
A convolution computes
|
||||
C = alpha * Conv2dFprop(A, B) + beta * C,
|
||||
so we set up data types for the input tensor A,
|
||||
weights tensor B, output tensor C,
|
||||
and the scaling factors alpha and beta.
|
||||
CUTLASS divides the convolution into two parts:
|
||||
the "mainloop" that computes X = Conv2dFprop(A, B),
|
||||
and the "epilogue" that computes C = alpha * X + beta * C.
|
||||
The epilogue is an element-wise operation on X and C.
|
||||
In this case, it is a linear combination,
|
||||
but other epilogues are possible.
|
||||
|
||||
CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines
|
||||
constitute the whole process of loading input data from global memory to shared memory, loading data
|
||||
from shared memory to registers, doing matrix multiplication, store to global memory. The below flow
|
||||
sequence shows a typical mma multistage pipeline.
|
||||
(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h)
|
||||
In this example, we want
|
||||
|
||||
tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers
|
||||
--mma--> registers --global stores--> output to global memory
|
||||
* the scaling factors alpha and beta to be float,
|
||||
|
||||
NVIDIA Ampere uses `cp_async` to build multistage software pipeline to better hide latencies.
|
||||
* the elements of A and B to be cutlass::half_t
|
||||
(a 16-bit floating-point type),
|
||||
|
||||
* the elements of C to be float, and
|
||||
|
||||
There are few more template variables initialized such as, which threadblock tile of output matrix
|
||||
is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on.
|
||||
* intermediate sums to be accumulated in float.
|
||||
|
||||
These are all put together to create a template variable which describes CUTLASS Implicit GEMM
|
||||
kernel using cutlass::conv::device::ImplicitGemm template.
|
||||
We convey this to the CUTLASS kernel
|
||||
by setting the following template parameters.
|
||||
|
||||
The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it.
|
||||
We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come
|
||||
in the way of learning CUTLASS.
|
||||
* alpha and beta: ElementComputeEpilogue = float
|
||||
|
||||
Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS
|
||||
kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64,
|
||||
R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the
|
||||
important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space
|
||||
memory required by the kernel we instantiated. If yes, we create it and pass it along with other
|
||||
arguments created to initialize CUTLASS kernel then, the kernel is launched.
|
||||
* Elements of input tensor A: ElementInputA = cutlass::half_t
|
||||
|
||||
In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to
|
||||
compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel.
|
||||
* Elements of input tensor B: ElementInputB = cutlass::half_t
|
||||
|
||||
* Elements of output tensor C: ElementOutput = float
|
||||
|
||||
* Accumulation type: ElementAccumulator = float
|
||||
|
||||
Next, we describe the layout of the input and output tensors.
|
||||
We convey this to the CUTLASS kernel
|
||||
by setting the following template parameters.
|
||||
|
||||
* Layout of input tensor A: LayoutInputA = TensorNHWC
|
||||
|
||||
* Layout of input tensor B: LayoutInputB = TensorNHWC
|
||||
|
||||
* Layout of output tensor C: LayoutOutput = TensorNHWC
|
||||
|
||||
After that, we set up rules to compute the epilogue.
|
||||
The epilogue in this case is a simple linear combination
|
||||
C = alpha * X + beta * C.
|
||||
Thus, we set the kernel's template parameter EpilogueOp
|
||||
to LinearCombination. LinearCombination itself
|
||||
has template parameters:
|
||||
|
||||
* the element type of the output tensor (ElementOutput),
|
||||
|
||||
* the number of elements per vector memory access (8),
|
||||
|
||||
* the data type of the accumulator (ElementAccumulator),
|
||||
|
||||
* and the data type used to compute the linear combination
|
||||
(ElementComputeEpilogue).
|
||||
|
||||
We then define the tile shapes
|
||||
that each level of the computation uses.
|
||||
We define these as types that encode the tile shapes
|
||||
as compile-time integer values.
|
||||
Each shape expresses the dimensions M x N x K.
|
||||
Here, the letters refer to the dimensions
|
||||
of a matrix-matrix multiply.
|
||||
|
||||
* ThreadblockShape defines the threadblock tile shape
|
||||
as 128 x 128 x 64.
|
||||
|
||||
* WarpShape defines the warp tile shape as 64 x 64 x 64.
|
||||
|
||||
* InstructionShape defines the MMA
|
||||
(matrix multiply-accumulate) operation shape
|
||||
as 16 x 8 x 16.
|
||||
|
||||
These types become template arguments
|
||||
of the kernel properties type
|
||||
cutlass::conv::kernel::DefaultConv2dFprop.
|
||||
The kernel uses these shapes to deduce
|
||||
the number of threads needed per threadblock,
|
||||
the required amount of shared memory,
|
||||
the internal layouts needed to access
|
||||
shared memory without bank conflicts,
|
||||
and many other properties that the kernel needs
|
||||
for good performance.
|
||||
CUTLASS deduces all these properties automatically,
|
||||
so that users don't have to.
|
||||
DefaultConv2dFprop accepts other template parameters
|
||||
that describe things like the target CUDA SM architecture.
|
||||
|
||||
CUTLASS also supports multiple MMA pipelines in a threadblock.
|
||||
An MMA pipeline constitutes the whole process
|
||||
of loading input data from global memory to shared memory,
|
||||
loading data from shared memory to registers,
|
||||
doing matrix multiplication,
|
||||
and storing the result to global memory.
|
||||
The below flow sequence shows a typical MMA multistage pipeline
|
||||
(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h).
|
||||
|
||||
tensor in global memory
|
||||
--cp_async-->
|
||||
tile in shared memory
|
||||
--smem loads-->
|
||||
registers
|
||||
--mma-->
|
||||
registers
|
||||
--global stores-->
|
||||
output to global memory
|
||||
|
||||
On NVIDIA Ampere, the kernel uses `cp_async`
|
||||
to build a multistage software pipeline.
|
||||
This helps it better hide latency.
|
||||
|
||||
At this point, we can define the actual CUTLASS kernel type
|
||||
as the alias ImplicitGemm, a specialization of
|
||||
cutlass::conv::device::ImplicitGemmConvolution.
|
||||
The latter accepts the kernel properties type alias
|
||||
Conv2dFpropKernel as its one template argument.
|
||||
|
||||
This example then sets up a test problem
|
||||
and arguments to the kernel.
|
||||
We use CUTLASS utilities to allocate
|
||||
the input and output tensors
|
||||
and fill them with sample input data.
|
||||
We then create the kernel arguments
|
||||
as an instance of ImplicitGemm::Arguments.
|
||||
The arguments include
|
||||
the problem size (N = 1, H = 64, W = 64, C = 128),
|
||||
filter size (K = 64, R = 3, S = 3, C = 128),
|
||||
padding, strides, dilation, tensors, alpha, beta,
|
||||
and the split k-dimension factor.
|
||||
We also query CUTLASS if the kernel we instantiated
|
||||
requires any memory for scratch space.
|
||||
If yes, we reserve scratch space and pass it along
|
||||
with other arguments to initialize the CUTLASS kernel.
|
||||
|
||||
After lauching the CUTLASS kernel, this example runs
|
||||
a reference convolution kernel (from CUTLASS utilities)
|
||||
to check correctness.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
@ -131,8 +229,8 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
// Data types for input and output tensors
|
||||
// and computation between elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
@ -143,39 +241,40 @@ using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
// Whether to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
// SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape
|
||||
// Threadblock tile shape
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape
|
||||
// Warp tile shape
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
// MMA (Tensor Core instruction, in this case) tile shape
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
// How the kernel schedules threadblocks
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
// Number of pipeline stages to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
// This code section describe iterator algorithm selected is Analytic or Optimized
|
||||
// Which iterator algorithm to use: Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
// The epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
// Kernel properties type
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
@ -193,6 +292,7 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
// Type of the actual kernel
|
||||
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -230,7 +330,7 @@ struct Options {
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
// Verify that the problem size is compatible with CUTLASS's convolution implementation
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
@ -256,7 +356,7 @@ struct Options {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
/// Update input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size) {
|
||||
@ -270,7 +370,7 @@ struct Options {
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
// Parse command-line arguments
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
@ -302,11 +402,11 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
@ -320,12 +420,12 @@ struct Options {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
/// Print an explanation of the command-line arguments
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "16_ampere_tensorop_conv2dfprop example\n\n"
|
||||
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< " This example uses Ampere's Tensor Core operators on F16 data types\n"
|
||||
<< " to compute forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n=<int> Input tensor extent N\n"
|
||||
@ -350,7 +450,7 @@ struct Options {
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
@ -360,19 +460,20 @@ struct Options {
|
||||
filter_size.n());
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
/// Compute performance in Gflop/s
|
||||
///
|
||||
/// Gflop/s stands for billions (10^9) of
|
||||
/// floating-point operations per second (Gflop/s).
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
@ -380,14 +481,14 @@ struct Result {
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
static std::ostream& print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
@ -404,7 +505,7 @@ struct Result {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
@ -420,8 +521,6 @@ struct Result {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
@ -441,7 +540,7 @@ Result profile_convolution(Options const &options) {
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
// Fill tensor A on host with uniformly distributed random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
@ -449,7 +548,7 @@ Result profile_convolution(Options const &options) {
|
||||
ElementInputA(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
// Fill tensor B on host with uniformly distributed random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
@ -457,7 +556,7 @@ Result profile_convolution(Options const &options) {
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with uniform-distribution random data
|
||||
// Fill tensor C on host with uniformly distributed random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
@ -490,7 +589,7 @@ Result profile_convolution(Options const &options) {
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
@ -501,7 +600,7 @@ Result profile_convolution(Options const &options) {
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
// Construct ImplicitGemm::Argument structure with conv2d
|
||||
// Construct ImplicitGemm::Argument structure with conv2d
|
||||
// problem size, data pointers, and epilogue values
|
||||
typename ImplicitGemm::Arguments arguments{
|
||||
problem_size,
|
||||
@ -539,7 +638,7 @@ Result profile_convolution(Options const &options) {
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on host...\n";
|
||||
|
||||
@ -552,8 +651,7 @@ Result profile_convolution(Options const &options) {
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
||||
ElementAccumulator
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.host_ref(),
|
||||
@ -564,7 +662,7 @@ Result profile_convolution(Options const &options) {
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
// Check if CUTLASS kernel and reference kernel produced the same output
|
||||
tensor_d.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
@ -589,14 +687,14 @@ Result profile_convolution(Options const &options) {
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "16_ampere_workspace_conv2dfprop_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
@ -616,7 +714,7 @@ Result profile_convolution(Options const &options) {
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
@ -632,7 +730,7 @@ Result profile_convolution(Options const &options) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
// Launch a sequence of implicit GEMM operations on the device.
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
@ -652,7 +750,7 @@ Result profile_convolution(Options const &options) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
// Measure elapsed runtime.
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
@ -660,7 +758,7 @@ Result profile_convolution(Options const &options) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
// Print average run time and floating-point throughput (Gflop/s).
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
@ -673,8 +771,6 @@ Result profile_convolution(Options const &options) {
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
@ -701,7 +797,7 @@ int main(int argc, char const **args) {
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
@ -768,5 +864,3 @@ int main(int argc, char const **args) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -470,8 +470,7 @@ Result profile_convolution(Options const &options) {
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
||||
ElementAccumulator
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.host_ref(),
|
||||
|
||||
@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when
|
||||
The example demonstrates how to reduce one of the operands of the GEMM along the k-dimension when
|
||||
computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere
|
||||
16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
|
||||
core instructions.
|
||||
|
||||
@ -31,6 +31,7 @@
|
||||
|
||||
cutlass_example_add_executable(
|
||||
24_gemm_grouped
|
||||
gemm_grouped.cu
|
||||
gemm_grouped.cu
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
leading dimensions and problem sizes are stored in arrays in GMEM.
|
||||
|
||||
This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM
|
||||
concept may be distinct.
|
||||
concept may be distinct.
|
||||
|
||||
This benchmark program initializes a workspace with random problem sizes for a given number of
|
||||
groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to
|
||||
@ -186,7 +186,7 @@ struct Options {
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
//
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
@ -216,7 +216,7 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("alignment", alignment, 8);
|
||||
cmd.get_cmd_line_argument("groups", problem_count, 15);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.0f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.0f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations, 20);
|
||||
cmd.get_cmd_line_argument("streams", cuda_streams, 0);
|
||||
cmd.get_cmd_line_argument("verbose", verbose, false);
|
||||
@ -455,13 +455,13 @@ struct Options {
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = int64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes) {
|
||||
fmas += problem.product();
|
||||
}
|
||||
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
@ -546,7 +546,7 @@ public:
|
||||
template <typename Element>
|
||||
void initialize_tensor(
|
||||
Element *ptr,
|
||||
size_t capacity,
|
||||
size_t capacity,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint32_t seed) {
|
||||
|
||||
@ -578,7 +578,7 @@ public:
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
ptr, capacity, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
@ -589,7 +589,7 @@ public:
|
||||
// Fill with increasing elements
|
||||
cutlass::reference::device::BlockFillSequential(
|
||||
ptr, capacity, Element(1), Element());
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Fill with all 1s
|
||||
@ -674,13 +674,13 @@ public:
|
||||
|
||||
ptr_A.reset(problem_count());
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
|
||||
ptr_B.reset(problem_count());
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
|
||||
ptr_C.reset(problem_count());
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
|
||||
ptr_D.reset(problem_count());
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
@ -712,7 +712,7 @@ public:
|
||||
MatrixCoord extent_A{problem.m(), problem.k()};
|
||||
MatrixCoord extent_B{problem.k(), problem.n()};
|
||||
MatrixCoord extent_C{problem.m(), problem.n()};
|
||||
|
||||
|
||||
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get() + offset_A.at(i), layout_A, extent_A);
|
||||
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get() + offset_B.at(i), layout_B, extent_B);
|
||||
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get() + offset_C.at(i), layout_C, extent_C);
|
||||
@ -724,18 +724,18 @@ public:
|
||||
cutlass::reference::device::GemmComplex<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ElementC, LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem,
|
||||
options.alpha,
|
||||
options.alpha,
|
||||
view_A,
|
||||
Gemm::kTransformA,
|
||||
view_B,
|
||||
Gemm::kTransformB,
|
||||
options.beta,
|
||||
view_C,
|
||||
view_Ref_device,
|
||||
options.beta,
|
||||
view_C,
|
||||
view_Ref_device,
|
||||
ElementAccumulator(0)
|
||||
);
|
||||
|
||||
@ -781,8 +781,8 @@ public:
|
||||
std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n";
|
||||
for (auto const & bin : this->options.problem_bins) {
|
||||
|
||||
std::cout << " [" << bin_idx << "]: "
|
||||
<< bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k()
|
||||
std::cout << " [" << bin_idx << "]: "
|
||||
<< bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k()
|
||||
<< ", batch count: " << bin.second.size() << "\n";
|
||||
|
||||
++bin_idx;
|
||||
@ -832,7 +832,7 @@ public:
|
||||
|
||||
for (auto const & bin : this->options.problem_bins) {
|
||||
int first_idx = bin.second.front();
|
||||
|
||||
|
||||
bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx));
|
||||
bin_count.push_back(int32_t(bin.second.size()));
|
||||
|
||||
@ -974,7 +974,7 @@ public:
|
||||
std::cerr << "CUTLASS error on line " << __LINE__ << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
@ -1027,7 +1027,7 @@ public:
|
||||
int last_stream_idx = 0;
|
||||
|
||||
for (int iter = 0; iter < this->options.iterations; ++iter) {
|
||||
|
||||
|
||||
for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) {
|
||||
|
||||
cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx];
|
||||
@ -1098,7 +1098,7 @@ public:
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Wait for work to be completed
|
||||
//
|
||||
@ -1129,10 +1129,10 @@ public:
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
|
||||
for (auto stream : cuda_streams) {
|
||||
if (stream) {
|
||||
(void)cudaStreamDestroy(stream);
|
||||
(void)cudaStreamDestroy(stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1203,8 +1203,8 @@ public:
|
||||
int tiles = Gemm::problem_tile_count(problem);
|
||||
total_tiles += tiles;
|
||||
|
||||
std::cout << " [" << idx << "]: "
|
||||
<< problem.m() << "-by-" << problem.n() << "-by-" << problem.k()
|
||||
std::cout << " [" << idx << "]: "
|
||||
<< problem.m() << "-by-" << problem.n() << "-by-" << problem.k()
|
||||
<< " (" << tiles << " threadblock tiles)" << "\n";
|
||||
|
||||
++idx;
|
||||
@ -1442,12 +1442,12 @@ int main(int argc, char const **args) {
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
|
||||
|
||||
|
||||
//
|
||||
// This example requires an NVIDIA Ampere-architecture GPU.
|
||||
//
|
||||
|
||||
std::cout
|
||||
std::cout
|
||||
<< "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or "
|
||||
<< "later (compute capability 80 or greater).\n";
|
||||
|
||||
@ -1497,9 +1497,9 @@ int main(int argc, char const **args) {
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
|
||||
@ -1519,8 +1519,8 @@ int main(int argc, char const **args) {
|
||||
cutlass::ComplexTransform::kNone,
|
||||
8,
|
||||
ElementOutput, LayoutC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
@ -1531,7 +1531,7 @@ int main(int argc, char const **args) {
|
||||
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
|
||||
// This parameter is passed in at present to match the APIs of other kernels. The parameter
|
||||
// is unused within the kernel.
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
|
||||
4>::GemmKernel;
|
||||
|
||||
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
@ -181,7 +181,7 @@ struct Options {
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n"
|
||||
<< "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm --m=1024 --n=512 \\\n"
|
||||
<< " --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
@ -27,9 +27,9 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
|
||||
cutlass_example_add_executable(
|
||||
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm
|
||||
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu
|
||||
29_3xtf32_complex_gemm
|
||||
29_3xtf32_complex_gemm.cu
|
||||
)
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@
|
||||
matrix multiply kernel to verify its correctness.
|
||||
|
||||
The CUTLASS Syrk template is instantiated in the function CutlassSsyrkNN. This is kernel computes
|
||||
the symmetric rank-k update (SYRK) using double-precision doubleing-point arithmetic and assumes
|
||||
the symmetric rank-k update (SYRK) using double-precision floating-point arithmetic and assumes
|
||||
all matrices have column-major layout.
|
||||
|
||||
The threadblock tile size is chosen as 16x32x16 which offers good performance for large matrices.
|
||||
|
||||
@ -34,7 +34,7 @@
|
||||
matrix multiply kernel to verify its correctness.
|
||||
|
||||
The CUTLASS Trmm template is instantiated in the function CutlassStrmmNN. This is kernel computes
|
||||
the triangular matrix product (TRMM) using double-precision doubleing-point arithmetic and assumes
|
||||
the triangular matrix product (TRMM) using double-precision floating-point arithmetic and assumes
|
||||
all matrices have column-major layout.
|
||||
|
||||
The threadblock tile size is chosen as 64x64x16 which offers good performance for large matrices.
|
||||
|
||||
@ -578,9 +578,21 @@ public:
|
||||
|
||||
int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
cudaError_t result;
|
||||
|
||||
if (gemm_smem_size >= (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
gemm_smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<gemm_grid, gemm_block, gemm_smem_size, stream>>>(params_.gemm);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return cutlass::Status::kErrorInternal;
|
||||
|
||||
@ -316,7 +316,11 @@ int run(Options &options) {
|
||||
// <- Fill tensor_b_indices on host with unique random integers
|
||||
std::vector<int> to_fill(problem_size.n()) ; // vector with ints.
|
||||
std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n()
|
||||
std::random_shuffle(to_fill.begin(), to_fill.end());
|
||||
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
|
||||
std::random_device make_seed;
|
||||
std::mt19937 source_of_randomness(make_seed());
|
||||
std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness);
|
||||
}
|
||||
memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int));
|
||||
|
||||
// Copy data from host to GPU
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
510
examples/39_gemm_permute/layouts.h
Normal file
510
examples/39_gemm_permute/layouts.h
Normal file
@ -0,0 +1,510 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines additional layout functions used in Permute GEMM example to simplify
|
||||
computing reference permutations of 4/5D tensors when source data is column-major.
|
||||
*/
|
||||
#pragma once
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cassert>
|
||||
#else
|
||||
#include "assert.h"
|
||||
#endif
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/tensor_coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace layout {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Mapping function for 4-D CWHN tensors.
|
||||
class TensorCWHN {
|
||||
public:
|
||||
/// Logical rank of tensor
|
||||
static int const kRank = 4;
|
||||
|
||||
/// Rank of stride vector
|
||||
static int const kStrideRank = 3;
|
||||
|
||||
/// Index type used for coordinates
|
||||
using Index = int32_t;
|
||||
|
||||
/// Long index type used for offsets
|
||||
using LongIndex = int64_t;
|
||||
|
||||
/// Logical coordinate (n, h, w, c)
|
||||
using TensorCoord = Tensor4DCoord;
|
||||
|
||||
/// Stride vector
|
||||
using Stride = Coord<kStrideRank>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride data member - [n, hn, whn]
|
||||
Stride stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCWHN(Stride const &stride = Stride(0)): stride_(stride) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCWHN(
|
||||
typename Stride::Index stride_h, ///< number of elements between adjacent N coordinates
|
||||
typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates
|
||||
typename Stride::Index stride_c ///< number of elements between adjacent W coordinates
|
||||
):
|
||||
stride_(make_Coord(stride_h, stride_w, stride_c)) { }
|
||||
|
||||
/// Constructor
|
||||
// Once convolutions implement 64b stride this ctor can be deleted
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCWHN(Coord<kStrideRank, LongIndex> const &stride):
|
||||
stride_(make_Coord(
|
||||
static_cast<typename Stride::Index>(stride[0]),
|
||||
static_cast<typename Stride::Index>(stride[1]),
|
||||
static_cast<typename Stride::Index>(stride[2]))
|
||||
) { }
|
||||
|
||||
/// Helper returns a layout to a tightly packed WCNH tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static TensorCWHN packed(TensorCoord const &extent) {
|
||||
return TensorCWHN(
|
||||
make_Coord(
|
||||
extent.n(),
|
||||
extent.h() * extent.n(),
|
||||
extent.w() * extent.h() * extent.n()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns the offset of a coordinate (n, h, w, c) in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(TensorCoord const &coord) const {
|
||||
return coord.n() +
|
||||
LongIndex(stride_[0] * coord.h()) +
|
||||
LongIndex(stride_[1] * coord.w()) +
|
||||
LongIndex(stride_[2] * coord.c());
|
||||
}
|
||||
|
||||
/// Returns the offset of a pitchlinear coordinate in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(PitchLinearCoord coord) const {
|
||||
return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride & stride() {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Compute the number of contiguous elements needed to store a tensor with the given size
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex capacity(TensorCoord const &extent) const {
|
||||
// it does not make sense if the extent is larger than stride
|
||||
// and we could not rely on the capacity calculation in such cases
|
||||
// we could move this checkers to debug code only
|
||||
if ((extent.n() > stride_[0])
|
||||
|| (extent.h() * stride_[0] > stride_[1])
|
||||
|| (extent.w() * stride_[1] > stride_[2])) {
|
||||
assert(0);
|
||||
}
|
||||
return extent.c() * stride_[2];
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Mapping function for 4-D NHCW tensors.
|
||||
class TensorNHCW {
|
||||
public:
|
||||
/// Logical rank of tensor
|
||||
static int const kRank = 4;
|
||||
|
||||
/// Rank of stride vector
|
||||
static int const kStrideRank = 3;
|
||||
|
||||
/// Index type used for coordinates
|
||||
using Index = int32_t;
|
||||
|
||||
/// Long index type used for offsets
|
||||
using LongIndex = int64_t;
|
||||
|
||||
/// Logical coordinate (n, h, w, c)
|
||||
using TensorCoord = Tensor4DCoord;
|
||||
|
||||
/// Stride vector
|
||||
using Stride = Coord<kStrideRank>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride data member - [w, cw, hcw]
|
||||
Stride stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNHCW(Stride const &stride = Stride(0)): stride_(stride) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNHCW(
|
||||
typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates
|
||||
typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates
|
||||
typename Stride::Index stride_n ///< number of elements between adjacent N coordinates
|
||||
):
|
||||
stride_(make_Coord(stride_c, stride_h, stride_n)) { }
|
||||
|
||||
/// Constructor
|
||||
// Once convolutions implement 64b stride this ctor can be deleted
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNHCW(Coord<kStrideRank, LongIndex> const &stride):
|
||||
stride_(make_Coord(
|
||||
static_cast<typename Stride::Index>(stride[0]),
|
||||
static_cast<typename Stride::Index>(stride[1]),
|
||||
static_cast<typename Stride::Index>(stride[2]))
|
||||
) { }
|
||||
|
||||
/// Helper returns a layout to a tightly packed WCNH tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static TensorNHCW packed(TensorCoord const &extent) {
|
||||
return TensorNHCW(
|
||||
make_Coord(
|
||||
extent.w(),
|
||||
extent.c() * extent.w(),
|
||||
extent.h() * extent.c() * extent.w()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns the offset of a coordinate (n, h, w, c) in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(TensorCoord const &coord) const {
|
||||
return coord.w() +
|
||||
LongIndex(stride_[0] * coord.c()) +
|
||||
LongIndex(stride_[1] * coord.h()) +
|
||||
LongIndex(stride_[2] * coord.n());
|
||||
}
|
||||
|
||||
/// Returns the offset of a pitchlinear coordinate in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(PitchLinearCoord coord) const {
|
||||
return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride & stride() {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Compute the number of contiguous elements needed to store a tensor with the given size
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex capacity(TensorCoord const &extent) const {
|
||||
// it does not make sense if the extent is larger than stride
|
||||
// and we could not rely on the capacity calculation in such cases
|
||||
// we could move this checkers to debug code only
|
||||
if ((extent.w() > stride_[0])
|
||||
|| (extent.c() * stride_[0] > stride_[1])
|
||||
|| (extent.h() * stride_[1] > stride_[2])) {
|
||||
assert(0);
|
||||
}
|
||||
return extent.n() * stride_[2];
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Mapping function for 4-D NHCW tensors.
|
||||
class TensorNCWH {
|
||||
public:
|
||||
/// Logical rank of tensor
|
||||
static int const kRank = 4;
|
||||
|
||||
/// Rank of stride vector
|
||||
static int const kStrideRank = 3;
|
||||
|
||||
/// Index type used for coordinates
|
||||
using Index = int32_t;
|
||||
|
||||
/// Long index type used for offsets
|
||||
using LongIndex = int64_t;
|
||||
|
||||
/// Logical coordinate (n, h, w, c)
|
||||
using TensorCoord = Tensor4DCoord;
|
||||
|
||||
/// Stride vector
|
||||
using Stride = Coord<kStrideRank>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride data member - [h, wh, cwh]
|
||||
Stride stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNCWH(Stride const &stride = Stride(0)): stride_(stride) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNCWH(
|
||||
typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates
|
||||
typename Stride::Index stride_c, ///< number of elements between adjacent H coordinates
|
||||
typename Stride::Index stride_n ///< number of elements between adjacent N coordinates
|
||||
):
|
||||
stride_(make_Coord(stride_w, stride_c, stride_n)) { }
|
||||
|
||||
/// Constructor
|
||||
// Once convolutions implement 64b stride this ctor can be deleted
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorNCWH(Coord<kStrideRank, LongIndex> const &stride):
|
||||
stride_(make_Coord(
|
||||
static_cast<typename Stride::Index>(stride[0]),
|
||||
static_cast<typename Stride::Index>(stride[1]),
|
||||
static_cast<typename Stride::Index>(stride[2]))
|
||||
) { }
|
||||
|
||||
/// Helper returns a layout to a tightly packed WCNH tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static TensorNCWH packed(TensorCoord const &extent) {
|
||||
return TensorNCWH(
|
||||
make_Coord(
|
||||
extent.h(),
|
||||
extent.w() * extent.h(),
|
||||
extent.c() * extent.w() * extent.h()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns the offset of a coordinate (n, h, w, c) in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(TensorCoord const &coord) const {
|
||||
return coord.h() +
|
||||
LongIndex(stride_[0] * coord.w()) +
|
||||
LongIndex(stride_[1] * coord.c()) +
|
||||
LongIndex(stride_[2] * coord.n());
|
||||
}
|
||||
|
||||
/// Returns the offset of a pitchlinear coordinate in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(PitchLinearCoord coord) const {
|
||||
return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride & stride() {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Compute the number of contiguous elements needed to store a tensor with the given size
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex capacity(TensorCoord const &extent) const {
|
||||
// it does not make sense if the extent is larger than stride
|
||||
// and we could not rely on the capacity calculation in such cases
|
||||
// we could move this checkers to debug code only
|
||||
if ((extent.h() > stride_[0])
|
||||
|| (extent.w() * stride_[0] > stride_[1])
|
||||
|| (extent.c() * stride_[1] > stride_[2])) {
|
||||
assert(0);
|
||||
}
|
||||
return extent.n() * stride_[2];
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Mapping function for 5-D CWHDN tensors.
|
||||
class TensorCWHDN {
|
||||
public:
|
||||
/// Logical rank of tensor
|
||||
static int const kRank = 5;
|
||||
|
||||
/// Rank of stride vector
|
||||
static int const kStrideRank = 4;
|
||||
|
||||
/// Index type used for coordinates
|
||||
using Index = int32_t;
|
||||
|
||||
/// Long index type used for offsets
|
||||
using LongIndex = int64_t;
|
||||
|
||||
/// Logical coordinate (n, d, h, w, c)
|
||||
using TensorCoord = Tensor5DCoord;
|
||||
|
||||
/// Stride vector
|
||||
using Stride = Coord<kStrideRank>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride data member - [n, dn, hdn, whdn]
|
||||
Stride stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCWHDN(Stride const &stride = Stride(0)): stride_(stride) { }
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCWHDN(
|
||||
typename Stride::Index n,
|
||||
typename Stride::Index dn,
|
||||
typename Stride::Index hdn,
|
||||
typename Stride::Index whdn):
|
||||
stride_(make_Coord(n, dn, hdn, whdn)) { }
|
||||
|
||||
/// Constructor
|
||||
// Once convolutions implement 64b stride this ctor can be deleted
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCWHDN(Coord<kStrideRank, LongIndex> const &stride):
|
||||
stride_(make_Coord(
|
||||
static_cast<typename Stride::Index>(stride[0]),
|
||||
static_cast<typename Stride::Index>(stride[1]),
|
||||
static_cast<typename Stride::Index>(stride[2]),
|
||||
static_cast<typename Stride::Index>(stride[3]))
|
||||
) { }
|
||||
|
||||
/// Helper returns a layout to a tightly packed CWHDN tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static TensorCWHDN packed(TensorCoord const &extent) {
|
||||
return TensorCWHDN(
|
||||
make_Coord(
|
||||
extent.n(),
|
||||
extent.d() * extent.n(),
|
||||
extent.h() * extent.d() * extent.n(),
|
||||
extent.w() * extent.h() * extent.d() * extent.n()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns the offset of a coordinate (n, d, h, w, c) in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(TensorCoord const &coord) const {
|
||||
return coord.n() +
|
||||
LongIndex(stride_[0] * coord.d()) +
|
||||
LongIndex(stride_[1] * coord.h()) +
|
||||
LongIndex(stride_[2] * coord.w()) +
|
||||
LongIndex(stride_[3] * coord.c());
|
||||
}
|
||||
|
||||
/// Returns the offset of a pitchlinear coordinate in linear memory.
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(PitchLinearCoord coord) const {
|
||||
return coord.contiguous() + LongIndex(coord.strided() * stride_[3]);
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Returns the stride of the layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride & stride() {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
/// Compute the number of contiguous elements needed to store a tensor with the given size
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex capacity(TensorCoord const &extent) const {
|
||||
// it does not make sense if the extent is larger than stride
|
||||
// and we could not rely on the capacity calculation in such cases
|
||||
// we could move this checkers to debug code only
|
||||
if ((extent.n() > stride_[0])
|
||||
|| (extent.d() * stride_[0] > stride_[1])
|
||||
|| (extent.h() * stride_[1] > stride_[2])
|
||||
|| (extent.w() * stride_[2] > stride_[3])) {
|
||||
assert(0);
|
||||
}
|
||||
return extent.c() * stride_[3];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layout
|
||||
} // namespace cutlass
|
||||
344
examples/39_gemm_permute/permute_info.h
Normal file
344
examples/39_gemm_permute/permute_info.h
Normal file
@ -0,0 +1,344 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Contains additional metadata about layout permute functions used in the example.
|
||||
*/
|
||||
|
||||
#include "cutlass/tensor_coord.h"
|
||||
#include "cutlass/layout/permute.h"
|
||||
|
||||
/// Additional permutation metadata to facilitate testing/printing
|
||||
template<typename PermuteLayout>
|
||||
struct PermuteInfo;
|
||||
|
||||
/// Specialization for default case (no permute). Other specializations must follow this template.
|
||||
template<>
|
||||
struct PermuteInfo<cutlass::layout::NoPermute> {
|
||||
|
||||
/// Whether this is a BMM or GEMM permutation (NoPermute can actually be either)
|
||||
static bool constexpr kBatched = false;
|
||||
|
||||
/// Minimal divisor for row extent
|
||||
static int constexpr kRowFactor = 1;
|
||||
|
||||
/// Minimum divisor for column extent
|
||||
static int constexpr kColumnFactor = 1;
|
||||
|
||||
/// Minimum divisor for batch size dimension
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
/// Tensor layout used in permutation operation
|
||||
using Layout = cutlass::layout::PackedVectorLayout;
|
||||
|
||||
static std::string name() {
|
||||
return "NoPermute";
|
||||
}
|
||||
|
||||
/// User-friendly description of the permute operation
|
||||
static std::string desc() {
|
||||
return "no permutation";
|
||||
}
|
||||
|
||||
/// Infer original higher-rank tensor shape from GEMM/BMM matrix extents.
|
||||
/// For direct (output) permutations, must be a simple reshape of extent.
|
||||
/// For inverse (input) permutations, must return shape *before* permute operation.
|
||||
/// In case of NoPermute, simply use a linear (rank 1) view of the memory
|
||||
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
return Layout::TensorCoord(extent.row() * extent.column() * batch_count);
|
||||
}
|
||||
|
||||
/// Compute the permuted higher-rank tensor shape from the original shape.
|
||||
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>> {
|
||||
|
||||
static bool constexpr kBatched = true;
|
||||
static int constexpr kRowFactor = 1;
|
||||
static int constexpr kColumnFactor = 1;
|
||||
static int constexpr kBatchFactor = D1;
|
||||
|
||||
using Layout = cutlass::layout::TensorNHWC;
|
||||
|
||||
static std::string name() {
|
||||
return "Tensor4DPermuteBMM0213<" + std::to_string(D1) + ">";
|
||||
}
|
||||
|
||||
static std::string desc() {
|
||||
return "batched GEMM permutation [0, 2, 1, 3]";
|
||||
}
|
||||
|
||||
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int D0 = batch_count / D1;
|
||||
int D2 = extent.row();
|
||||
int D3 = extent.column();
|
||||
return {D0, D1, D2, D3};
|
||||
}
|
||||
|
||||
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
|
||||
return {s[0], s[2], s[1], s[3]};
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D1>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>> {
|
||||
|
||||
static bool constexpr kBatched = true;
|
||||
static int constexpr kRowFactor = 1;
|
||||
static int constexpr kColumnFactor = D1;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Base = PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>>;
|
||||
using Layout = typename Base::Layout;
|
||||
|
||||
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int D0 = batch_count;
|
||||
int D2 = extent.row();
|
||||
int D3 = extent.column() / D1;
|
||||
return {D0, D1, D2, D3};
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>> {
|
||||
|
||||
static bool constexpr kBatched = true;
|
||||
static int constexpr kRowFactor = 1;
|
||||
static int constexpr kColumnFactor = 1;
|
||||
static int constexpr kBatchFactor = D1;
|
||||
|
||||
using Layout = cutlass::layout::TensorNHCW;
|
||||
|
||||
static std::string name() {
|
||||
return "Tensor4DPermuteBMM0321<" + std::to_string(D1) + ">";
|
||||
}
|
||||
|
||||
static std::string desc() {
|
||||
return "batched GEMM permutation [0, 3, 2, 1]";
|
||||
}
|
||||
|
||||
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int D0 = batch_count / D1;
|
||||
int D2 = extent.row();
|
||||
int D3 = extent.column();
|
||||
return {D0, D1, D2, D3};
|
||||
}
|
||||
|
||||
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
|
||||
return {s[0], s[3], s[2], s[1]};
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D1>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>> {
|
||||
|
||||
static bool constexpr kBatched = true;
|
||||
static int constexpr kRowFactor = D1;
|
||||
static int constexpr kColumnFactor = 1;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Base = PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>>;
|
||||
using Layout = typename Base::Layout;
|
||||
|
||||
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int D0 = batch_count;
|
||||
int D2 = extent.row() / D1;
|
||||
int D3 = extent.column();
|
||||
return {D0, D1, D2, D3};
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1, int D2>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>> {
|
||||
|
||||
static bool constexpr kBatched = false;
|
||||
static int constexpr kRowFactor = D1;
|
||||
static int constexpr kColumnFactor = D2;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Layout = cutlass::layout::TensorNHWC;
|
||||
|
||||
static std::string name() {
|
||||
return "Tensor4DPermute0213<" + std::to_string(D1) + "," + std::to_string(D2) + ">";
|
||||
}
|
||||
|
||||
static std::string desc() {
|
||||
return "normal GEMM permutation [0, 2, 1, 3]";
|
||||
}
|
||||
|
||||
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int D0 = extent.row() / D1;
|
||||
int D3 = extent.column() / D2;
|
||||
return {D0, D1, D2, D3};
|
||||
}
|
||||
|
||||
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
|
||||
return {s[0], s[2], s[1], s[3]};
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1, int D2>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajorInverse<D1, D2>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>> {
|
||||
|
||||
static bool constexpr kBatched = false;
|
||||
static int constexpr kRowFactor = D2;
|
||||
static int constexpr kColumnFactor = D1;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Base = PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>>;
|
||||
using Layout = typename Base::Layout;
|
||||
|
||||
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int D0 = extent.row() / D2;
|
||||
int D3 = extent.column() / D1;
|
||||
return {D0, D1, D2, D3};
|
||||
}
|
||||
};
|
||||
|
||||
template<int D1, int D2>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213ColumnMajor<D1, D2>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>> {
|
||||
using Layout = cutlass::layout::TensorCWHN;
|
||||
};
|
||||
|
||||
template<int D1, int D2>
|
||||
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213ColumnMajorInverse<D1, D2>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajorInverse<D1, D2>> {
|
||||
using Layout = cutlass::layout::TensorCWHN;
|
||||
};
|
||||
|
||||
template<int T1, int T2, int T3>
|
||||
struct PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>> {
|
||||
|
||||
static bool constexpr kBatched = false;
|
||||
static int constexpr kRowFactor = T1;
|
||||
static int constexpr kColumnFactor = T2 * T3;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Layout = cutlass::layout::TensorNDHWC;
|
||||
|
||||
static std::string name() {
|
||||
return "Tensor5DPermute20314<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">";
|
||||
}
|
||||
|
||||
static std::string desc() {
|
||||
return "normal GEMM permutation [2, 0, 3, 1, 4]";
|
||||
}
|
||||
|
||||
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count)
|
||||
{
|
||||
int const T0 = extent.row() / T1;
|
||||
int const T4 = extent.column() / (T2 * T3);
|
||||
return {T0, T1, T2, T3, T4};
|
||||
}
|
||||
|
||||
static Layout::TensorCoord permute(Layout::TensorCoord const &s)
|
||||
{
|
||||
return {s[2], s[0], s[3], s[1], s[4]};
|
||||
}
|
||||
};
|
||||
|
||||
template<int T1, int T2, int T3>
|
||||
struct PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajorInverse<T1, T2, T3>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>> {
|
||||
|
||||
static bool constexpr kBatched = false;
|
||||
static int constexpr kRowFactor = T2;
|
||||
static int constexpr kColumnFactor = T1 * T3;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Base = PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>>;
|
||||
using Layout = typename Base::Layout;
|
||||
|
||||
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int const T0 = extent.row() / T2;
|
||||
int const T4 = extent.column() / (T1 * T3);
|
||||
return {T0, T1, T2, T3, T4};
|
||||
}
|
||||
};
|
||||
|
||||
template<int T1, int T2, int T3>
|
||||
struct PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>> {
|
||||
|
||||
static bool constexpr kBatched = false;
|
||||
static int constexpr kRowFactor = T1;
|
||||
static int constexpr kColumnFactor = T2 * T3;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Layout = cutlass::layout::TensorCWHDN;
|
||||
|
||||
static std::string name() {
|
||||
return "Tensor5DPermute02413<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">";
|
||||
}
|
||||
|
||||
static std::string desc() {
|
||||
return "normal GEMM permutation [0, 2, 4, 1, 3]";
|
||||
}
|
||||
|
||||
using Coord = cutlass::Tensor5DCoord;
|
||||
|
||||
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count)
|
||||
{
|
||||
int const T0 = extent.row() / T1;
|
||||
int const T4 = extent.column() / (T2 * T3);
|
||||
return {T0, T1, T2, T3, T4};
|
||||
}
|
||||
|
||||
static Layout::TensorCoord permute(Layout::TensorCoord const &s)
|
||||
{
|
||||
return {s[0], s[2], s[4], s[1], s[3]};
|
||||
}
|
||||
};
|
||||
|
||||
template<int T1, int T2, int T3>
|
||||
struct PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajorInverse<T1, T2, T3>>
|
||||
: public PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>> {
|
||||
|
||||
static bool constexpr kBatched = false;
|
||||
static int constexpr kRowFactor = T2;
|
||||
static int constexpr kColumnFactor = T1 * T3;
|
||||
static int constexpr kBatchFactor = 1;
|
||||
|
||||
using Base = PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>>;
|
||||
using Layout = typename Base::Layout;
|
||||
|
||||
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
|
||||
int const T0 = extent.row() / T2;
|
||||
int const T4 = extent.column() / (T1 * T3);
|
||||
return {T0, T1, T2, T3, T4};
|
||||
}
|
||||
};
|
||||
@ -1,22 +1,4 @@
|
||||
# CUTLASS Python Interface Examples
|
||||
This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples:
|
||||
* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations
|
||||
* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel
|
||||
# PyCUTLASS Examples
|
||||
|
||||
## Setting up the Python interface
|
||||
Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API.
|
||||
|
||||
## Running examples
|
||||
Each of the basic examples can be run as follows:
|
||||
```shell
|
||||
# Run the GEMM example
|
||||
python gemm.py
|
||||
|
||||
# Run the Conv2d example
|
||||
python conv2d.py
|
||||
|
||||
# Run the grouped GEMM example
|
||||
python gemm_grouped.py
|
||||
```
|
||||
|
||||
To run the customizable examples, refer to the README in the [customizable](customizable) directory.
|
||||
This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
|
||||
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.
|
||||
|
||||
@ -33,15 +33,20 @@
|
||||
Basic example of using the CUTLASS Python interface to run a 2d convolution
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import sys
|
||||
print("This example is deprecated. Please see examples/python for examples of using "
|
||||
"the CUTLASS Python interface.")
|
||||
sys.exit(0)
|
||||
|
||||
import cutlass
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import cutlass_bindings
|
||||
import cutlass.backend as pycutlass
|
||||
from cutlass.backend import *
|
||||
from cutlass.backend.utils.reference_model import Conv2dReferenceModule
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -76,11 +81,11 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
pycutlass.compiler.nvcc()
|
||||
|
||||
# Set up A, B, C and accumulator
|
||||
A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
|
||||
B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
|
||||
C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment)
|
||||
element_acc = cutlass.float32
|
||||
element_epilogue = cutlass.float32
|
||||
A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment)
|
||||
B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment)
|
||||
C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.TensorNHWC, alignment)
|
||||
element_acc = cutlass_bindings.float32
|
||||
element_epilogue = cutlass_bindings.float32
|
||||
|
||||
# Select instruction shape based on the Tensor Core instructions supported
|
||||
# by the device on which we are running
|
||||
@ -89,12 +94,14 @@ if cc == 70:
|
||||
elif cc == 75:
|
||||
instruction_shape = [16, 8, 8]
|
||||
else:
|
||||
# Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used)
|
||||
cc = 80
|
||||
instruction_shape = [16, 8, 16]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
instruction_shape,
|
||||
A.element, B.element, element_acc,
|
||||
cutlass.OpClass.TensorOp,
|
||||
cutlass_bindings.OpClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
)
|
||||
|
||||
@ -108,8 +115,8 @@ tile_description = TileDescription(
|
||||
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=cutlass.conv.Operator.fprop,
|
||||
iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
|
||||
conv_kind=cutlass_bindings.conv.Operator.fprop,
|
||||
iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized,
|
||||
arch=cc, tile_description=tile_description,
|
||||
A=A, B=B, C=C, stride_support=StrideSupport.Strided,
|
||||
epilogue_functor=epilogue_functor
|
||||
@ -125,20 +132,20 @@ pycutlass.compiler.add_module(operations)
|
||||
|
||||
# Randomly initialize tensors
|
||||
|
||||
problem_size = cutlass.conv.Conv2dProblemSize(
|
||||
cutlass.Tensor4DCoord(args.n, args.h, args.c, args.w),
|
||||
cutlass.Tensor4DCoord(args.k, args.r, args.s, args.c),
|
||||
cutlass.Tensor4DCoord(0, 0, 0, 0), # Padding
|
||||
cutlass.MatrixCoord(1, 1), # Strides
|
||||
cutlass.MatrixCoord(1, 1), # Dilation
|
||||
cutlass.conv.Mode.cross_correlation,
|
||||
problem_size = cutlass_bindings.conv.Conv2dProblemSize(
|
||||
cutlass_bindings.Tensor4DCoord(args.n, args.h, args.c, args.w),
|
||||
cutlass_bindings.Tensor4DCoord(args.k, args.r, args.s, args.c),
|
||||
cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), # Padding
|
||||
cutlass_bindings.MatrixCoord(1, 1), # Strides
|
||||
cutlass_bindings.MatrixCoord(1, 1), # Dilation
|
||||
cutlass_bindings.conv.Mode.cross_correlation,
|
||||
1, # Split k slices
|
||||
1 # Groups
|
||||
)
|
||||
|
||||
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size)
|
||||
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size)
|
||||
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size)
|
||||
tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size)
|
||||
tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size)
|
||||
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size)
|
||||
|
||||
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
|
||||
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
|
||||
|
||||
@ -165,28 +165,3 @@ Example 7: GELU
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
|
||||
```
|
||||
### Epilogue Visitor Tree
|
||||
Example 1:
|
||||
```python
|
||||
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 2:
|
||||
```python
|
||||
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 3:
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 4:
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 5:
|
||||
```python
|
||||
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
|
||||
```
|
||||
Example 6:
|
||||
```python
|
||||
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
|
||||
```
|
||||
|
||||
@ -29,13 +29,18 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
import numpy as np
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.conv2d_operation import *
|
||||
from pycutlass.utils import reference_model
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
import sys
|
||||
print("This example is deprecated. Please see examples/python for examples of using "
|
||||
"the CUTLASS Python interface.")
|
||||
sys.exit(0)
|
||||
|
||||
import numpy as np
|
||||
import cutlass.backend as pycutlass
|
||||
from cutlass.backend import *
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass.backend.conv2d_operation import *
|
||||
from cutlass.backend.utils.reference_model import Conv2dReferenceModule
|
||||
import torch.nn.functional as F
|
||||
|
||||
import argparse
|
||||
@ -62,7 +67,7 @@ parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
|
||||
help='Data type of accumulator')
|
||||
parser.add_argument('-m', "--math", default="multiply_add",
|
||||
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
|
||||
parser.add_argument('-op', "--opcode", default="simt", type=str,
|
||||
parser.add_argument('-op', "--opcode", default="Simt", type=str,
|
||||
choices=["Simt", 'TensorOp'],
|
||||
help='This option describes whether you want to use tensor \
|
||||
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
|
||||
@ -156,12 +161,12 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
element_a = getattr(cutlass, args.element_a)
|
||||
element_b = getattr(cutlass, args.element_b)
|
||||
element_c = getattr(cutlass, args.element_c)
|
||||
element_acc = getattr(cutlass, args.element_acc)
|
||||
element_a = getattr(cutlass_bindings, args.element_a)
|
||||
element_b = getattr(cutlass_bindings, args.element_b)
|
||||
element_c = getattr(cutlass_bindings, args.element_c)
|
||||
element_acc = getattr(cutlass_bindings, args.element_acc)
|
||||
math_operation = getattr(MathOperation, args.math)
|
||||
opclass = getattr(cutlass.OpClass, args.opcode)
|
||||
opclass = getattr(cutlass_bindings.OpClass, args.opcode)
|
||||
|
||||
math_inst = MathInstruction(
|
||||
args.instruction_shape, element_a, element_b,
|
||||
@ -173,9 +178,9 @@ tile_description = TileDescription(
|
||||
math_inst
|
||||
)
|
||||
|
||||
layout_a = getattr(cutlass, args.layout_a)
|
||||
layout_b = getattr(cutlass, args.layout_b)
|
||||
layout_c = getattr(cutlass, args.layout_c)
|
||||
layout_a = getattr(cutlass_bindings, args.layout_a)
|
||||
layout_b = getattr(cutlass_bindings, args.layout_b)
|
||||
layout_c = getattr(cutlass_bindings, args.layout_c)
|
||||
|
||||
A = TensorDescription(
|
||||
element_a, layout_a, args.alignment_a
|
||||
@ -189,7 +194,7 @@ C = TensorDescription(
|
||||
element_c, layout_c, args.alignment_c
|
||||
)
|
||||
|
||||
element_epilogue = getattr(cutlass, args.element_epilogue)
|
||||
element_epilogue = getattr(cutlass_bindings, args.element_epilogue)
|
||||
if (args.activation_function == "identity"
|
||||
or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
|
||||
#
|
||||
@ -200,10 +205,10 @@ else:
|
||||
getattr(pycutlass, args.activation_function)(element_epilogue),
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
|
||||
swizzling_functor = getattr(cutlass, args.swizzling_functor)
|
||||
iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, args.iterator_algorithm)
|
||||
swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor)
|
||||
stride_support = getattr(StrideSupport, args.stride_support)
|
||||
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
|
||||
conv_kind = getattr(cutlass_bindings.conv.Operator, args.conv_kind)
|
||||
|
||||
operation = Conv2dOperation(
|
||||
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
|
||||
@ -226,7 +231,7 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
|
||||
getattr(pycutlass, args.activation_function)(element_epilogue),
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
reduction_operation = ReductionOperation(
|
||||
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
|
||||
shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment),
|
||||
C=C, element_accumulator=element_acc,
|
||||
element_compute=element_epilogue,
|
||||
epilogue_functor=epilogue_functor_reduction,
|
||||
@ -236,34 +241,34 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
|
||||
|
||||
pycutlass.compiler.add_module(operations)
|
||||
|
||||
problem_size = cutlass.conv.Conv2dProblemSize(
|
||||
cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
|
||||
cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
|
||||
cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
|
||||
cutlass.MatrixCoord(args.stride[0], args.stride[1]),
|
||||
cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
|
||||
cutlass.conv.Mode.cross_correlation,
|
||||
problem_size = cutlass_bindings.conv.Conv2dProblemSize(
|
||||
cutlass_bindings.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
|
||||
cutlass_bindings.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
|
||||
cutlass_bindings.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
|
||||
cutlass_bindings.MatrixCoord(args.stride[0], args.stride[1]),
|
||||
cutlass_bindings.MatrixCoord(args.dilation[0], args.dilation[1]),
|
||||
cutlass_bindings.conv.Mode.cross_correlation,
|
||||
args.split_k_slices, 1
|
||||
)
|
||||
|
||||
|
||||
# User-provide inputs
|
||||
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
|
||||
tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
|
||||
tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
if args.bias:
|
||||
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
|
||||
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
|
||||
conv_kind, problem_size
|
||||
).at(3)
|
||||
else:
|
||||
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
|
||||
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
|
||||
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
|
||||
tensor_D_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
|
||||
conv_kind, problem_size
|
||||
)
|
||||
|
||||
@ -288,12 +293,12 @@ arguments = Conv2dArguments(
|
||||
operation=operation, problem_size=problem_size, A=tensor_A,
|
||||
B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
|
||||
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
|
||||
split_k_mode=getattr(cutlass_bindings.conv.SplitKMode, args.split_k_mode),
|
||||
split_k_slices=problem_size.split_k_slices
|
||||
)
|
||||
|
||||
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
|
||||
implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
|
||||
implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
|
||||
reduction_arguments = ReductionArguments(
|
||||
reduction_operation,
|
||||
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
|
||||
|
||||
@ -29,13 +29,18 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
import numpy as np
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import cutlass
|
||||
from bfloat16 import bfloat16
|
||||
|
||||
import sys
|
||||
print("This example is deprecated. Please see examples/python for examples of using "
|
||||
"the CUTLASS Python interface.")
|
||||
sys.exit(0)
|
||||
|
||||
import numpy as np
|
||||
import cutlass.backend as pycutlass
|
||||
from cutlass.backend import *
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
import cutlass_bindings
|
||||
from bfloat16 import bfloat16
|
||||
|
||||
import argparse
|
||||
|
||||
@ -62,7 +67,7 @@ parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
|
||||
help='Data type of accumulator')
|
||||
parser.add_argument('-m', "--math", default="multiply_add",
|
||||
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
|
||||
parser.add_argument('-op', "--opcode", default="simt", type=str,
|
||||
parser.add_argument('-op', "--opcode", default="Simt", type=str,
|
||||
choices=["Simt", 'TensorOp'],
|
||||
help="This option describes whether you want to use tensor \
|
||||
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
|
||||
@ -100,8 +105,6 @@ parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
|
||||
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
|
||||
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
|
||||
help="This option describes the epilogue part of the kernel")
|
||||
parser.add_argument("-epv", "--epilogue_visitor", default=None,
|
||||
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
|
||||
# swizzling
|
||||
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
|
||||
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
|
||||
@ -147,12 +150,12 @@ pycutlass.compiler.nvcc()
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
element_a = getattr(cutlass, args.element_a)
|
||||
element_b = getattr(cutlass, args.element_b)
|
||||
element_c = getattr(cutlass, args.element_c)
|
||||
element_acc = getattr(cutlass, args.element_acc)
|
||||
element_a = getattr(cutlass_bindings, args.element_a)
|
||||
element_b = getattr(cutlass_bindings, args.element_b)
|
||||
element_c = getattr(cutlass_bindings, args.element_c)
|
||||
element_acc = getattr(cutlass_bindings, args.element_acc)
|
||||
math_operation = getattr(MathOperation, args.math)
|
||||
opclass = getattr(cutlass.OpClass, args.opcode)
|
||||
opclass = getattr(cutlass_bindings.OpClass, args.opcode)
|
||||
|
||||
math_inst = MathInstruction(
|
||||
args.instruction_shape, element_a, element_b,
|
||||
@ -164,9 +167,9 @@ tile_description = TileDescription(
|
||||
math_inst
|
||||
)
|
||||
|
||||
layout_a = getattr(cutlass, args.layout_a)
|
||||
layout_b = getattr(cutlass, args.layout_b)
|
||||
layout_c = getattr(cutlass, args.layout_c)
|
||||
layout_a = getattr(cutlass_bindings, args.layout_a)
|
||||
layout_b = getattr(cutlass_bindings, args.layout_b)
|
||||
layout_c = getattr(cutlass_bindings, args.layout_c)
|
||||
|
||||
A = TensorDescription(
|
||||
element_a, layout_a, args.alignment_a
|
||||
@ -180,7 +183,7 @@ C = TensorDescription(
|
||||
element_c, layout_c, args.alignment_c
|
||||
)
|
||||
|
||||
element_epilogue = getattr(cutlass, args.element_epilogue)
|
||||
element_epilogue = getattr(cutlass_bindings, args.element_epilogue)
|
||||
if (args.activation_function == "identity"
|
||||
or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
|
||||
#
|
||||
@ -191,73 +194,12 @@ else:
|
||||
getattr(pycutlass, args.activation_function)(element_epilogue),
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
swizzling_functor = getattr(cutlass, args.swizzling_functor)
|
||||
|
||||
visitor = args.epilogue_visitor is not None
|
||||
|
||||
if args.epilogue_visitor == "ColumnReduction":
|
||||
class ColumnReduction_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
D = alpha * accum + beta * c
|
||||
reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
|
||||
return D, reduction
|
||||
epilogue_functor = ColumnReduction_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
elif args.epilogue_visitor == "RowReduction":
|
||||
class RowReduction_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
D = alpha * accum + tanh.numpy(beta * c)
|
||||
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
|
||||
return D, reduction
|
||||
epilogue_functor = RowReduction_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
|
||||
elif args.epilogue_visitor == "RowBroadcast":
|
||||
class RowBroadcast_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
vector: 'row', alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
T = accum + vector
|
||||
scale_T = alpha * T
|
||||
Z = relu.numpy(scale_T + beta * c)
|
||||
return Z, T
|
||||
epilogue_functor = RowBroadcast_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
elif args.epilogue_visitor == "ColumnBroadcast":
|
||||
class ColumnBroadcast_(EpilogueVisitTree):
|
||||
def __call__(
|
||||
self, accum: 'tensor', c: 'tensor',
|
||||
vector: 'column', alpha: 'scalar', beta: 'scalar'):
|
||||
#
|
||||
T = accum + vector
|
||||
scale_T = leaky_relu.numpy(alpha * T, 0.2)
|
||||
Z = scale_T + beta * c
|
||||
return Z, T
|
||||
epilogue_functor = ColumnBroadcast_(
|
||||
epilogue_functor, tile_description, math_inst.element_accumulator,
|
||||
C.alignment, element_epilogue, C.element)
|
||||
epilogue_functor.initialize()
|
||||
else:
|
||||
epilogue_functor = epilogue_functor
|
||||
swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor)
|
||||
|
||||
operation = GemmOperationUniversal(
|
||||
arch=args.compute_capability, tile_description=tile_description,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
|
||||
visitor=visitor
|
||||
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
|
||||
)
|
||||
|
||||
if args.print_cuda:
|
||||
@ -275,7 +217,7 @@ if args.gemm_mode == "GemmSplitKParallel":
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
|
||||
reduction_operation = ReductionOperation(
|
||||
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
|
||||
shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment),
|
||||
C=C, element_accumulator=element_acc,
|
||||
element_compute=element_epilogue,
|
||||
epilogue_functor=epilogue_functor_reduction,
|
||||
@ -287,7 +229,7 @@ pycutlass.compiler.add_module(operations)
|
||||
|
||||
# User-provide inputs
|
||||
|
||||
problem_size = cutlass.gemm.GemmCoord(
|
||||
problem_size = cutlass_bindings.gemm.GemmCoord(
|
||||
args.problem_size[0], args.problem_size[1], args.problem_size[2])
|
||||
|
||||
tensor_a_size = args.batch * problem_size.m() * problem_size.k()
|
||||
@ -347,44 +289,13 @@ tensor_D = np.zeros(
|
||||
shape=(args.batch * problem_size.m() * problem_size.n(),)
|
||||
).astype(getattr(np, args.element_c))
|
||||
|
||||
if args.epilogue_visitor == "RowReduction":
|
||||
cta_n = args.threadblock_shape[1]
|
||||
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
|
||||
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
|
||||
output_op = operation.epilogue_type(
|
||||
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
elif args.epilogue_visitor == "ColumnReduction":
|
||||
cta_m = args.threadblock_shape[0]
|
||||
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
|
||||
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
|
||||
output_op = operation.epilogue_type(
|
||||
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
elif args.epilogue_visitor == "RowBroadcast":
|
||||
vector = np.ceil(
|
||||
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
|
||||
).astype(getattr(np, args.element_c))
|
||||
tensor_t = np.empty_like(tensor_D)
|
||||
output_op = operation.epilogue_type(
|
||||
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
elif args.epilogue_visitor == "ColumnBroadcast":
|
||||
vector = np.ceil(
|
||||
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
|
||||
).astype(getattr(np, args.element_c))
|
||||
tensor_t = np.empty_like(tensor_D)
|
||||
output_op = operation.epilogue_type(
|
||||
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
|
||||
)
|
||||
else:
|
||||
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
|
||||
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
|
||||
|
||||
arguments = GemmArguments(
|
||||
operation=operation, problem_size=problem_size,
|
||||
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
|
||||
output_op=output_op,
|
||||
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
|
||||
gemm_mode=getattr(cutlass_bindings.gemm.Mode, args.gemm_mode),
|
||||
split_k_slices=args.split_k_slices, batch=args.batch
|
||||
)
|
||||
|
||||
@ -411,38 +322,8 @@ reference = ReferenceModule(A, B, C)
|
||||
tensor_D_ref = reference.run(
|
||||
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
|
||||
|
||||
if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
|
||||
tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
|
||||
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
|
||||
|
||||
if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
|
||||
output_op.sync()
|
||||
accum_ref = reference.run(
|
||||
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
|
||||
tensor_D_ref, reduction_ref = epilogue_functor(
|
||||
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
|
||||
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
|
||||
args.alpha, args.beta
|
||||
)
|
||||
tensor_D_ref = tensor_D_ref.flatten()
|
||||
reduction_ref = reduction_ref.flatten()
|
||||
assert np.allclose(reduction_ref, reduction, atol=1e-2)
|
||||
|
||||
elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
|
||||
output_op.sync()
|
||||
accum_ref = reference.run(
|
||||
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
|
||||
|
||||
tensor_D_ref, tensor_T_ref = epilogue_functor(
|
||||
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
|
||||
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
|
||||
vector, args.alpha, args.beta)
|
||||
|
||||
tensor_D_ref = tensor_D_ref.flatten()
|
||||
tensor_T_ref = tensor_T_ref.flatten()
|
||||
|
||||
assert np.array_equal(tensor_t, tensor_T_ref)
|
||||
|
||||
try:
|
||||
assert np.array_equal(tensor_D, tensor_D_ref)
|
||||
except:
|
||||
|
||||
@ -29,12 +29,17 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
################################################################################
|
||||
import numpy as np
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import csv
|
||||
|
||||
import sys
|
||||
print("This example is deprecated. Please see examples/python for examples of using "
|
||||
"the CUTLASS Python interface.")
|
||||
sys.exit(0)
|
||||
|
||||
import numpy as np
|
||||
import cutlass.backend as pycutlass
|
||||
from cutlass.backend import *
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
import csv
|
||||
|
||||
import argparse
|
||||
|
||||
@ -61,7 +66,7 @@ parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
|
||||
help='Data type of accumulator')
|
||||
parser.add_argument('-m', "--math", default="multiply_add",
|
||||
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
|
||||
parser.add_argument('-op', "--opcode", default="simt", type=str,
|
||||
parser.add_argument('-op', "--opcode", default="Simt", type=str,
|
||||
choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \
|
||||
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
|
||||
# tile description
|
||||
@ -111,7 +116,7 @@ parser.add_argument("-pm", "--precompute_mode",
|
||||
default="Device", type=str, choices=["Host", "Device"],
|
||||
help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)")
|
||||
# arguments
|
||||
parser.add_argument("-p", "--problem_size_dir", type=str,
|
||||
parser.add_argument("-p", "--problem_size_dir", type=str, default="grouped_gemm_problem_size.csv",
|
||||
help="path to the csv file contains the problem sizes")
|
||||
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
|
||||
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
|
||||
@ -139,12 +144,12 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
element_a = getattr(cutlass, args.element_a)
|
||||
element_b = getattr(cutlass, args.element_b)
|
||||
element_c = getattr(cutlass, args.element_c)
|
||||
element_acc = getattr(cutlass, args.element_acc)
|
||||
element_a = getattr(cutlass_bindings, args.element_a)
|
||||
element_b = getattr(cutlass_bindings, args.element_b)
|
||||
element_c = getattr(cutlass_bindings, args.element_c)
|
||||
element_acc = getattr(cutlass_bindings, args.element_acc)
|
||||
math_operation = getattr(MathOperation, args.math)
|
||||
opclass = getattr(cutlass.OpClass, args.opcode)
|
||||
opclass = getattr(cutlass_bindings.OpClass, args.opcode)
|
||||
|
||||
math_inst = MathInstruction(
|
||||
args.instruction_shape, element_a, element_b,
|
||||
@ -156,9 +161,9 @@ tile_description = TileDescription(
|
||||
math_inst
|
||||
)
|
||||
|
||||
layout_a = getattr(cutlass, args.layout_a)
|
||||
layout_b = getattr(cutlass, args.layout_b)
|
||||
layout_c = getattr(cutlass, args.layout_c)
|
||||
layout_a = getattr(cutlass_bindings, args.layout_a)
|
||||
layout_b = getattr(cutlass_bindings, args.layout_b)
|
||||
layout_c = getattr(cutlass_bindings, args.layout_c)
|
||||
|
||||
A = TensorDescription(
|
||||
element_a, layout_a, args.alignment_a
|
||||
@ -172,7 +177,7 @@ C = TensorDescription(
|
||||
element_c, layout_c, args.alignment_c
|
||||
)
|
||||
|
||||
element_epilogue = getattr(cutlass, args.element_epilogue)
|
||||
element_epilogue = getattr(cutlass_bindings, args.element_epilogue)
|
||||
if args.activation_function == "identity":
|
||||
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
@ -180,7 +185,7 @@ else:
|
||||
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
|
||||
getattr(pycutlass, args.activation_function)(element_epilogue),
|
||||
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
|
||||
swizzling_functor = getattr(cutlass, args.swizzling_functor)
|
||||
swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor)
|
||||
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
|
||||
|
||||
operation = GemmOperationGrouped(
|
||||
@ -203,7 +208,7 @@ with open(args.problem_size_dir) as csv_file:
|
||||
reader = csv.reader(csv_file)
|
||||
for row in reader:
|
||||
problem_sizes.append(
|
||||
cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
|
||||
cutlass_bindings.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
|
||||
)
|
||||
|
||||
problem_count = len(problem_sizes)
|
||||
|
||||
@ -33,14 +33,18 @@
|
||||
Basic example of using the CUTLASS Python interface to run a GEMM
|
||||
"""
|
||||
|
||||
import sys
|
||||
print("This example is deprecated. Please see examples/python for examples of using "
|
||||
"the CUTLASS Python interface.")
|
||||
sys.exit(0)
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
import cutlass
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import cutlass_bindings
|
||||
import cutlass.backend as pycutlass
|
||||
from cutlass.backend import *
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'")
|
||||
@ -72,11 +76,11 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
pycutlass.compiler.nvcc()
|
||||
|
||||
# Set up A, B, C and accumulator
|
||||
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
|
||||
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
|
||||
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
|
||||
element_acc = cutlass.float32
|
||||
element_epilogue = cutlass.float32
|
||||
A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment)
|
||||
B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment)
|
||||
C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment)
|
||||
element_acc = cutlass_bindings.float32
|
||||
element_epilogue = cutlass_bindings.float32
|
||||
|
||||
# Select instruction shape based on the Tensor Core instructions supported
|
||||
# by the device on which we are running
|
||||
@ -85,12 +89,14 @@ if cc == 70:
|
||||
elif cc == 75:
|
||||
instruction_shape = [16, 8, 8]
|
||||
else:
|
||||
# Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used)
|
||||
cc = 80
|
||||
instruction_shape = [16, 8, 16]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
instruction_shape,
|
||||
A.element, B.element, element_acc,
|
||||
cutlass.OpClass.TensorOp,
|
||||
cutlass_bindings.OpClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
)
|
||||
|
||||
@ -122,7 +128,7 @@ tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,)
|
||||
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32)
|
||||
tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32)
|
||||
|
||||
problem_size = cutlass.gemm.GemmCoord(args.m, args.n, args.k)
|
||||
problem_size = cutlass_bindings.gemm.GemmCoord(args.m, args.n, args.k)
|
||||
alpha = 1.
|
||||
beta = 0.
|
||||
|
||||
|
||||
@ -33,14 +33,18 @@
|
||||
Basic example of using the CUTLASS Python interface to run a grouped GEMM
|
||||
"""
|
||||
|
||||
import sys
|
||||
print("This example is deprecated. Please see examples/python for examples of using "
|
||||
"the CUTLASS Python interface.")
|
||||
sys.exit(0)
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
import cutlass
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import cutlass_bindings
|
||||
import cutlass.backend as pycutlass
|
||||
from cutlass.backend import *
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python")
|
||||
@ -65,11 +69,11 @@ pycutlass.compiler.nvcc()
|
||||
|
||||
# Set up A, B, C and accumulator
|
||||
alignment = 1
|
||||
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
|
||||
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
|
||||
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
|
||||
element_acc = cutlass.float32
|
||||
element_epilogue = cutlass.float32
|
||||
A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment)
|
||||
B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment)
|
||||
C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment)
|
||||
element_acc = cutlass_bindings.float32
|
||||
element_epilogue = cutlass_bindings.float32
|
||||
|
||||
# Select instruction shape based on the Tensor Core instructions supported
|
||||
# by the device on which we are running
|
||||
@ -78,12 +82,14 @@ if cc == 70:
|
||||
elif cc == 75:
|
||||
instruction_shape = [16, 8, 8]
|
||||
else:
|
||||
# Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used)
|
||||
cc = 80
|
||||
instruction_shape = [16, 8, 16]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
instruction_shape,
|
||||
A.element, B.element, element_acc,
|
||||
cutlass.OpClass.TensorOp,
|
||||
cutlass_bindings.OpClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
)
|
||||
|
||||
@ -112,8 +118,8 @@ pycutlass.compiler.add_module(operations)
|
||||
|
||||
# Initialize tensors for each problem in the group
|
||||
problem_sizes = [
|
||||
cutlass.gemm.GemmCoord(128, 128, 64),
|
||||
cutlass.gemm.GemmCoord(512, 256, 128)
|
||||
cutlass_bindings.gemm.GemmCoord(128, 128, 64),
|
||||
cutlass_bindings.gemm.GemmCoord(512, 256, 128)
|
||||
]
|
||||
problem_count = len(problem_sizes)
|
||||
|
||||
|
||||
@ -37,8 +37,20 @@ cutlass_example_add_executable(
|
||||
fused_multihead_attention_variable_seqlen.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
41_fused_multi_head_attention_backward
|
||||
fused_multi_head_attention_backward.cu
|
||||
DISABLE_TESTS ON
|
||||
)
|
||||
|
||||
|
||||
add_custom_target(41_fused_multi_head_attention
|
||||
DEPENDS 41_fused_multi_head_attention_fixed_seqlen
|
||||
41_fused_multi_head_attention_variable_seqlen
|
||||
41_fused_multi_head_attention_backward
|
||||
)
|
||||
|
||||
add_test(
|
||||
NAME ctest_examples_41_fmha_backward_python
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/fmha_backward_test.py $<TARGET_FILE:41_fused_multi_head_attention_backward>
|
||||
)
|
||||
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -50,6 +50,7 @@
|
||||
|
||||
#include "fmha_grouped.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
#include "gemm/custom_mma.h"
|
||||
#include "gemm/find_default_mma.h"
|
||||
#include "gemm/mma_from_smem.h"
|
||||
|
||||
@ -70,7 +71,7 @@ template <
|
||||
bool isAligned_,
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration,
|
||||
int kMaxK = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
|
||||
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
|
||||
>
|
||||
struct DefaultFMHAGrouped {
|
||||
@ -85,6 +86,8 @@ struct DefaultFMHAGrouped {
|
||||
|
||||
using ArchTag = ArchTag_;
|
||||
static bool const kIsAligned = isAligned_;
|
||||
static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock;
|
||||
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
static int const kWarpSize = 32;
|
||||
static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
|
||||
|
||||
@ -145,14 +148,20 @@ struct DefaultFMHAGrouped {
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
kStages,
|
||||
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
||||
? 4
|
||||
: DefaultConfig::kStages,
|
||||
Operator
|
||||
>::DefaultMma;
|
||||
|
||||
using MmaCore = typename DefaultMma::MmaCore;
|
||||
using IteratorA = typename DefaultMma::IteratorA;
|
||||
using IteratorB = typename DefaultMma::IteratorB;
|
||||
using Mma = typename DefaultMma::ThreadblockMma;
|
||||
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
|
||||
using Mma = typename cutlass::platform::conditional<
|
||||
kSingleValueIteration,
|
||||
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
|
||||
DefaultThreadblockMma>::type;
|
||||
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
||||
typename Mma::Operator::IteratorC,
|
||||
ElementAccumulator,
|
||||
@ -232,14 +241,24 @@ struct DefaultFMHAGrouped {
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
||||
? 4
|
||||
: DefaultConfig::kStages,
|
||||
kSplitKSerial,
|
||||
Operator>;
|
||||
|
||||
using WarpIteratorA = typename cutlass::gemm::threadblock::
|
||||
DefaultWarpIteratorAFromSharedMemory<
|
||||
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
|
||||
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
|
||||
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
|
||||
typename DefaultGemm::Mma::Policy>::WarpIterator;
|
||||
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage,
|
||||
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
|
||||
WarpIteratorA,
|
||||
false>; // kScaleOperandA
|
||||
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
@ -256,10 +275,6 @@ struct DefaultFMHAGrouped {
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_accum_t>;
|
||||
|
||||
struct SharedStorageMM1 {
|
||||
typename Mma::SharedStorage mm;
|
||||
};
|
||||
};
|
||||
|
||||
/// Define the kernel in terms of the default kernel
|
||||
|
||||
200
examples/41_fused_multi_head_attention/fmha_backward_test.py
Normal file
200
examples/41_fused_multi_head_attention/fmha_backward_test.py
Normal file
@ -0,0 +1,200 @@
|
||||
import argparse
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME
|
||||
import math
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable")
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(0)
|
||||
dtype = torch.float16
|
||||
B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128
|
||||
causal = True
|
||||
repeat_count = 100
|
||||
|
||||
ATOL = {
|
||||
torch.float: 5e-4,
|
||||
torch.half: 9.5e-2,
|
||||
torch.bfloat16: 7e-1,
|
||||
}[dtype]
|
||||
|
||||
RTOL = {
|
||||
torch.float: 1e-4,
|
||||
torch.half: 2e-2,
|
||||
torch.bfloat16: 1e-1,
|
||||
}[dtype]
|
||||
|
||||
|
||||
assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ"
|
||||
|
||||
fmha_bw_binary = args.example_exe
|
||||
if not os.path.isfile(fmha_bw_binary):
|
||||
print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""")
|
||||
sys.exit(1)
|
||||
|
||||
def create_lower_triangular_mask():
|
||||
return torch.triu(torch.full( # type: ignore
|
||||
[1, Mq, Mkv],
|
||||
dtype=dtype,
|
||||
fill_value=float("-inf"),
|
||||
), diagonal=1)
|
||||
|
||||
def ref_mha_bmk(q, k, v, mask):
|
||||
# Multi-head attention with inputs/outputs in BMK format
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
v = v.float()
|
||||
|
||||
q = q * (1 / q.shape[-1] ** 0.5)
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
if mask is not None:
|
||||
attn += mask
|
||||
attn_max = attn.max(-1, True).values
|
||||
attn_norm = (attn - attn_max).exp().sum(-1, True)
|
||||
attn = attn.softmax(-1)
|
||||
lse = attn_max + attn_norm.log()
|
||||
lse = lse.squeeze(2)
|
||||
return attn @ v, lse
|
||||
|
||||
|
||||
def bmhk2bmk(t):
|
||||
return t.permute((0, 2, 1, 3)).reshape(
|
||||
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
|
||||
)
|
||||
|
||||
def ref_mha_bmhk(q, k, v, mask):
|
||||
# Multi-head attention with inputs/outputs in BMHK format
|
||||
assert q.ndim == 4
|
||||
|
||||
out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask)
|
||||
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
|
||||
return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]])
|
||||
|
||||
def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta):
|
||||
lse = lse[:, :, :q.shape[1]] #BMH, unpad Q dimension
|
||||
delta = delta.reshape([-1, delta.shape[-1], 1])
|
||||
|
||||
# bmhk -> bmk
|
||||
q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)]
|
||||
|
||||
attn_T = k @ q.transpose(-2, -1)
|
||||
if mask is not None:
|
||||
attn_T += mask.transpose(-2, -1)
|
||||
attn_T = attn_T * (1 / q.shape[-1] ** 0.5)
|
||||
attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]])
|
||||
attn_T = attn_T.exp()
|
||||
|
||||
grad_v = attn_T @ grad_out
|
||||
|
||||
dov = grad_out @ v.transpose(-2, -1)
|
||||
tmp = (dov - delta) * attn_T.transpose(-2, -1)
|
||||
tmp = tmp / (q.shape[-1] ** 0.5)
|
||||
|
||||
grad_q = tmp @ k
|
||||
grad_k = tmp.transpose(-2, -1) @ q
|
||||
|
||||
return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]]
|
||||
|
||||
|
||||
print("initializing tensors...")
|
||||
query = torch.randn([B, Mq, H, K], dtype=dtype)
|
||||
key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype)
|
||||
value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype)
|
||||
mask = create_lower_triangular_mask() if causal else None
|
||||
|
||||
# let PyTorch compute gradients
|
||||
query.requires_grad_(True)
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
print("computing fw...")
|
||||
out, lse = ref_mha_bmhk(query, key, value, mask=mask)
|
||||
out = out.to(dtype).contiguous()
|
||||
grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype)
|
||||
|
||||
print("computing bw with autograd...")
|
||||
out.backward(grad_out)
|
||||
scale = (1 / query.shape[-1] ** 0.5)
|
||||
|
||||
|
||||
# Additional data needed by the kernel
|
||||
delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous()
|
||||
pad_amount = (32 - (lse.shape[2] % 32)) % 32
|
||||
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
|
||||
|
||||
print("computing bw with reference implem...")
|
||||
gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta)
|
||||
|
||||
with PipedSubprocess(fmha_bw_binary) as bw_kernel:
|
||||
# Send kernel arguments
|
||||
bw_kernel.write(
|
||||
TORCH_DTYPE_NAME[query.dtype],
|
||||
"scale", scale,
|
||||
"head_dim", K,
|
||||
"head_dim_value", Kv,
|
||||
"num_queries", Mq,
|
||||
"num_keys", Mkv,
|
||||
"num_heads", H,
|
||||
"custom_mask_type", (1 if causal else 0),
|
||||
"num_batches", B,
|
||||
"repeat_count", repeat_count,
|
||||
"num_splits_key", (Mkv // 128),
|
||||
)
|
||||
bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
|
||||
bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
|
||||
bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"])
|
||||
bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"])
|
||||
bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"])
|
||||
bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"])
|
||||
bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"])
|
||||
|
||||
if bw_kernel.read() != "OK":
|
||||
print("Got unexpected output")
|
||||
print(bw_kernel.subp.communicate()[0])
|
||||
sys.exit(0)
|
||||
|
||||
# Read kernel output
|
||||
gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float()
|
||||
gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float()
|
||||
gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float()
|
||||
runtime_ms = float(bw_kernel.readNamed("runtime_ms"))
|
||||
|
||||
float_ops = B * H * sum([
|
||||
# att = Q @ K.transpose
|
||||
Mq * Mkv * K * 2,
|
||||
# att @ dO
|
||||
Mkv * Mq * Kv * 2,
|
||||
# dov = dO @ V
|
||||
Mq * Kv * Mkv * 2,
|
||||
# dov @ K
|
||||
Mq * K * Mkv * 2,
|
||||
# dov @ Q
|
||||
Mq * K * Mkv * 2,
|
||||
])
|
||||
if causal:
|
||||
float_ops //= 2
|
||||
|
||||
print(f"""
|
||||
Fused multi-head attention - backward
|
||||
batch_size={B}
|
||||
num_queries={Mq}
|
||||
num_keys={Mkv}
|
||||
num_heads={H}
|
||||
head_dim={K}
|
||||
head_dim_value={Kv}
|
||||
|
||||
Correctness:
|
||||
grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()})
|
||||
grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()})
|
||||
grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()})
|
||||
(atol={ATOL} / rtol={RTOL})
|
||||
Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops)
|
||||
""")
|
||||
|
||||
assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
|
||||
assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
|
||||
assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
|
||||
@ -147,6 +147,9 @@ public:
|
||||
static int const kThreadsPerWarp = 32;
|
||||
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
|
||||
|
||||
static constexpr int kNumWarpsPerBlock =
|
||||
kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
|
||||
|
||||
using ProblemVisitor = FMHAGroupedProblemVisitor<
|
||||
ThreadblockShape,
|
||||
kGroupScheduleMode,
|
||||
@ -369,13 +372,16 @@ public:
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
|
||||
addition_storage;
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::Mma::SharedStorage mm1;
|
||||
};
|
||||
|
||||
union {
|
||||
@ -397,7 +403,7 @@ public:
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::Mma::SharedStorage mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
@ -490,6 +496,7 @@ public:
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
|
||||
auto& mi = shared_storage.mi;
|
||||
auto& out_rescale = shared_storage.out_rescale;
|
||||
|
||||
ProblemVisitor problem_visitor(
|
||||
params.problem_visitor,
|
||||
@ -512,6 +519,7 @@ public:
|
||||
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = ElementAccumulator(0);
|
||||
out_rescale[thread_id()] = accum_t(1.0);
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
|
||||
@ -568,7 +576,7 @@ public:
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
|
||||
MM1::Mma::prologue(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.mm1,
|
||||
iterator_V,
|
||||
thread_id(),
|
||||
problem_size_1_k);
|
||||
@ -623,6 +631,8 @@ public:
|
||||
|
||||
if (kPreloadV) {
|
||||
prologueV(0);
|
||||
} else {
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
}
|
||||
|
||||
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
||||
@ -649,30 +659,48 @@ public:
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
num_keys - iter_key_start >= kKeysPerBlock,
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<
|
||||
typename MM0::Mma::Operator::IteratorC,
|
||||
kFullColumns,
|
||||
kIsFirst>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
kSupportsBias ? 1.0f : params.scale);
|
||||
}));
|
||||
}));
|
||||
// DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
// DISPATCH_BOOL(
|
||||
// num_keys - iter_key_start >= kKeysPerBlock,
|
||||
// kFullColumns,
|
||||
// ([&] {
|
||||
// // Update `mi` from accum stored in registers
|
||||
// // Also does accum[i] <- exp(accum[i] - mi)
|
||||
// iterative_softmax<
|
||||
// typename MM0::Mma::Operator::IteratorC,
|
||||
// kFullColumns,
|
||||
// kIsFirst>(
|
||||
// accum_o,
|
||||
// accum,
|
||||
// mi,
|
||||
// m_prime,
|
||||
// s_prime,
|
||||
// lane_id(),
|
||||
// thread_id(),
|
||||
// warp_id(),
|
||||
// num_keys - iter_key_start,
|
||||
// iteratorC_tile_offset,
|
||||
// kSupportsBias ? 1.0f : params.scale);
|
||||
// }));
|
||||
// }));
|
||||
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
out_rescale,
|
||||
shared_storage.addition_storage,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
num_keys - iter_key_start,
|
||||
iter_key_start == 0,
|
||||
iteratorC_tile_offset,
|
||||
kSupportsBias ? 1.0f : params.scale);
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = warp_id() %
|
||||
@ -717,12 +745,14 @@ public:
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
|
||||
typename MM1::Mma mma_pv(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.si,
|
||||
// operand A: Pij_dropped in shared memory
|
||||
shared_storage.after_mm0.si.accum_ref(),
|
||||
// operand B: shared memory staging area for Vj, which is loaded
|
||||
// from global memory
|
||||
shared_storage.after_mm0.mm1.operand_B_ref(),
|
||||
(int)thread_id(),
|
||||
(int)warp_id(),
|
||||
(int)lane_id(),
|
||||
(int)problem_size_1_k);
|
||||
(int)lane_id());
|
||||
|
||||
mma_pv.set_prologue_done(kPreloadV);
|
||||
if (!kKeepOutputInRF) {
|
||||
@ -737,6 +767,7 @@ public:
|
||||
}
|
||||
|
||||
if (!kKeepOutputInRF) {
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
DISPATCH_BOOL(
|
||||
iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
@ -787,7 +818,7 @@ public:
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
EpilogueOutputOp rescale(s_prime, out_rescale);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
@ -836,34 +867,37 @@ public:
|
||||
typename MM1::OutputTileIteratorAccum // source tile
|
||||
>;
|
||||
auto dest_iter = createOutputIter(0);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
EpilogueOutputOp rescale(s_prime, out_rescale);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
epilogue(rescale, dest_iter, accum_o);
|
||||
}
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
__syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename WarpIteratorC,
|
||||
bool kFullColumns,
|
||||
bool kIsFirst>
|
||||
template <typename WarpIteratorC>
|
||||
CUTLASS_DEVICE static void iterative_softmax(
|
||||
typename WarpIteratorC::Fragment& frag_o, // output so far
|
||||
typename WarpIteratorC::Fragment& frag,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
|
||||
addition_storage,
|
||||
int8_t lane_id,
|
||||
int8_t thread_id,
|
||||
int8_t warp_id,
|
||||
int16_t max_col,
|
||||
int max_col,
|
||||
bool is_first,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
float scaling) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
@ -884,12 +918,11 @@ public:
|
||||
kThreadsPerWarp>::Iterator;
|
||||
// Convert to `accum_t` (rather than double)
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (!kIsFirst) {
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
m_prime[thread_id] = mi[thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
|
||||
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
|
||||
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
auto lane_offset =
|
||||
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
|
||||
@ -903,46 +936,64 @@ public:
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
if (accum_n < max_col) {
|
||||
max = cutlass::fast_max(max, frag[idx]);
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {
|
||||
// Having 4x atomicMax seems faster than reduce within warp
|
||||
// first...
|
||||
atomicMaxFloat(&mi[accum_m], max * scaling);
|
||||
atomicMaxFloat(&mi[accum_m], max);
|
||||
});
|
||||
}
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
// Make sure we all share the update values for `mi`
|
||||
__syncthreads();
|
||||
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
|
||||
m_prime[thread_id] = m_prime_exp;
|
||||
s_prime[thread_id] *= m_prime_exp;
|
||||
// Doing this `exp` is quite expensive. Let's
|
||||
// split it across the warps
|
||||
bool restore_mi_to_minus_inf = false;
|
||||
if (lane_id < kLinesPerWarp) {
|
||||
int id = warp_id * kLinesPerWarp + lane_id;
|
||||
auto m_prime_id = m_prime[id];
|
||||
auto mi_id = mi[id];
|
||||
bool changed = m_prime_id < mi_id; // `false` if both are -inf
|
||||
if (changed) {
|
||||
auto m_prime_exp = exp2f(m_prime_id - mi_id);
|
||||
out_rescale[id] = m_prime_exp;
|
||||
s_prime[id] *= m_prime_exp;
|
||||
} else {
|
||||
// Only when bias is enabled, it's possible that all the first values
|
||||
// of attention are masked to `-inf`. In that case we want to avoid
|
||||
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
|
||||
if (kSupportsBias &&
|
||||
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
|
||||
restore_mi_to_minus_inf = true;
|
||||
mi[id] = 0.0f;
|
||||
}
|
||||
out_rescale[id] = 1.0f;
|
||||
}
|
||||
}
|
||||
__syncthreads(); // Update output fragments
|
||||
if (kKeepOutputInRF && !kIsFirst) {
|
||||
accum_t mp;
|
||||
if (kKeepOutputInRF && !is_first) {
|
||||
accum_t line_rescale;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mp = m_prime[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
|
||||
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag_o[idx] = frag_o[idx] * line_rescale;
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
// Update accum_m, accum_n, ...
|
||||
{
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
|
||||
[&](int accum_m) { mi_row = mi[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] = (kFullColumns || accum_n < max_col)
|
||||
? exp2f(frag[idx] - mi_row)
|
||||
: accum_t(0.0);
|
||||
frag[idx] =
|
||||
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
LambdaIterator::iterateRows(
|
||||
@ -954,10 +1005,31 @@ public:
|
||||
lane_id, total_row, [](accum_t a, accum_t b) {
|
||||
return a + b;
|
||||
})) {
|
||||
atomicAdd(&s_prime[accum_m], total_row);
|
||||
// NOTE: we could atomically add `total_row` to `s_prime`, but
|
||||
// it's faster (and deterministic) to avoid atomics here
|
||||
addition_storage
|
||||
[accum_m + kQueriesPerBlock * tile_offset.column()] =
|
||||
total_row;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (lane_id < kLinesPerWarp) {
|
||||
int id = warp_id * kLinesPerWarp + lane_id;
|
||||
accum_t total_row = s_prime[id];
|
||||
if (restore_mi_to_minus_inf) {
|
||||
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
|
||||
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
} else {
|
||||
m_prime[id] = mi[id];
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
|
||||
total_row += addition_storage[id + kQueriesPerBlock * i];
|
||||
}
|
||||
s_prime[id] = total_row;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -0,0 +1,298 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holdvr nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "kernel_backward.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
|
||||
|
||||
using Arch = cutlass::arch::Sm80;
|
||||
static constexpr int kMaxK = 128;
|
||||
|
||||
template <typename ArchTag, typename Element, int kMaxK>
|
||||
struct DefaultKernel {
|
||||
// Some heuristics to select the best kernel (tested on Sm60, Sm70, Sm80)
|
||||
// NOTE: Requires quite a lot of shmem for Sm80+,
|
||||
// so might require tweaking those manually for Sm86/Sm89
|
||||
|
||||
static constexpr bool kSupports64x128 =
|
||||
ArchTag::kMinComputeCapability >= 80 ||
|
||||
(ArchTag::kMinComputeCapability >= 70 &&
|
||||
cutlass::sizeof_bits<Element>::value <= 16);
|
||||
static constexpr int kBlockSizeI = kSupports64x128 && kMaxK > 64 ? 128 : 64;
|
||||
static constexpr bool kIsHalf = cutlass::sizeof_bits<Element>::value <= 16;
|
||||
static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
|
||||
static constexpr bool kPreload = kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF;
|
||||
static constexpr int kBlockSizeJ = kPreload && kMaxK > 64 ? 128 : 64;
|
||||
|
||||
using Kernel = AttentionBackwardKernel<
|
||||
Arch,
|
||||
Element,
|
||||
true, // kIsAligned_
|
||||
false, // kApplyDropout_
|
||||
kPreload, // kPreload_
|
||||
kBlockSizeI, // kBlockSizeI_,
|
||||
kBlockSizeJ, // kBlockSizeJ_,
|
||||
kMaxK, // kMaxK
|
||||
false, // kKeysQueriesAlignedToBlockSize
|
||||
true // kEnableSplitKeys
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
template <typename T> struct TypeName;
|
||||
template <> struct TypeName<float> { static constexpr const char* Name = "f32"; };
|
||||
template <> struct TypeName<cutlass::half_t> { static constexpr const char* Name = "f16"; };
|
||||
template <> struct TypeName<cutlass::bfloat16_t> { static constexpr const char* Name = "b16"; };
|
||||
|
||||
void readExpect(std::string const& expected) {
|
||||
std::string read;
|
||||
std::cin >> read;
|
||||
if (read != expected) {
|
||||
std::cerr << "FATAL: Read '" << read << "' but expected '" << expected << "'" << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helpers to read from stdin
|
||||
template <typename Element>
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> readTensorOnDevice(std::string const& expectedName) {
|
||||
readExpect("tensor_begin");
|
||||
readExpect(std::string(TypeName<Element>::Name) + ":" + expectedName);
|
||||
uint64_t len = 0;
|
||||
std::cin >> len;
|
||||
readExpect("file");
|
||||
std::string filename;
|
||||
std::cin >> filename;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor({int64_t(1), int64_t(len / sizeof(Element))});
|
||||
uint8_t* data = (uint8_t*)tensor.host_data();
|
||||
|
||||
std::fstream myFile(filename, std::ios::in | std::ios::binary );
|
||||
myFile.read((char*)data, len);
|
||||
readExpect("tensor_end");
|
||||
tensor.sync_device();
|
||||
return tensor;
|
||||
}
|
||||
|
||||
int64_t readInt64(std::string const& expectedName) {
|
||||
readExpect(expectedName);
|
||||
int64_t s = 0;
|
||||
std::cin >> s;
|
||||
return s;
|
||||
}
|
||||
|
||||
float readFloat(std::string const& expectedName) {
|
||||
readExpect(expectedName);
|
||||
float s = 0;
|
||||
std::cin >> s;
|
||||
return s;
|
||||
}
|
||||
|
||||
// Writing
|
||||
template <typename Element>
|
||||
void writeTensor(std::string const& name, cutlass::HostTensor<Element, cutlass::layout::RowMajor>& tensor) {
|
||||
tensor.sync_host(); // device->host
|
||||
size_t u8len = tensor.size() * sizeof(Element);
|
||||
|
||||
// Python is expected to provide a file name to write to
|
||||
readExpect("tmpfile");
|
||||
std::string tmpfile;
|
||||
std::cin >> tmpfile;
|
||||
|
||||
uint8_t* data = (uint8_t*)tensor.host_data();
|
||||
std::fstream myFile(tmpfile, std::ios::out | std::ios::binary );
|
||||
myFile.write((char*)data, u8len);
|
||||
myFile.close();
|
||||
|
||||
std::cout << "tensor_begin " << TypeName<Element>::Name << ":" << name << " ";
|
||||
std::cout << u8len << " file " << tmpfile << " tensor_end" << std::endl;
|
||||
}
|
||||
|
||||
void writeInt64(std::string const& name, int64_t value) {
|
||||
std::cout << name << " " << value << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element>
|
||||
int runKernel() {
|
||||
using Kernel = typename DefaultKernel<Arch, Element, kMaxK>::Kernel;
|
||||
|
||||
#define READ_I64(NAME) p.NAME = (decltype(p.NAME))readInt64(#NAME)
|
||||
#define READ_TENSOR_AND_STRIDES_BMH(DT, NAME, NAME_XS) \
|
||||
auto storage##NAME = readTensorOnDevice<DT>(#NAME); \
|
||||
p.NAME##_ptr = storage##NAME.device_data(); \
|
||||
READ_I64(NAME_XS##_strideB); \
|
||||
READ_I64(NAME_XS##_strideM); \
|
||||
READ_I64(NAME_XS##_strideH);
|
||||
|
||||
#define CUDA_CHECK(FN) { \
|
||||
auto cudaError = FN; \
|
||||
if (cudaError != cudaSuccess) { \
|
||||
std::cerr << "FATAL: " #FN " failed: " << cudaGetErrorString(cudaError) << std::endl; \
|
||||
return -1; \
|
||||
} \
|
||||
}
|
||||
|
||||
typename Kernel::Params p;
|
||||
p.scale = readFloat("scale");
|
||||
READ_I64(head_dim);
|
||||
READ_I64(head_dim_value);
|
||||
READ_I64(num_queries);
|
||||
READ_I64(num_keys);
|
||||
READ_I64(num_heads);
|
||||
READ_I64(custom_mask_type);
|
||||
READ_I64(num_batches);
|
||||
int64_t repeat_count = readInt64("repeat_count");
|
||||
READ_I64(num_splits_key);
|
||||
|
||||
READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
|
||||
READ_TENSOR_AND_STRIDES_BMH(Element, key, k);
|
||||
READ_TENSOR_AND_STRIDES_BMH(Element, value, v);
|
||||
auto lse = readTensorOnDevice<typename Kernel::lse_scalar_t>("logsumexp");
|
||||
p.logsumexp_ptr = lse.device_data();
|
||||
p.lse_strideB = readInt64("lse_strideB");
|
||||
p.lse_strideH = readInt64("lse_strideH");
|
||||
|
||||
// output
|
||||
auto stOutput = readTensorOnDevice<Element>("output");
|
||||
p.output_ptr = stOutput.device_data();
|
||||
READ_I64(o_strideB);
|
||||
auto o_strideM = readInt64("o_strideM");
|
||||
if (o_strideM != p.o_strideM()) {
|
||||
std::cerr << "Invalid `o_strideM`: " << o_strideM << " - expected " << p.o_strideM();
|
||||
return 2;
|
||||
}
|
||||
READ_I64(o_strideH);
|
||||
|
||||
READ_TENSOR_AND_STRIDES_BMH(Element, grad_output, gO);
|
||||
|
||||
auto stDelta = readTensorOnDevice<typename Kernel::accum_t>("delta");
|
||||
p.delta_ptr = stDelta.device_data();
|
||||
READ_I64(delta_strideB);
|
||||
READ_I64(delta_strideH);
|
||||
|
||||
// Allocate workspace
|
||||
if (p.workspace_size()) {
|
||||
cudaMalloc(&p.workspace, p.workspace_size());
|
||||
}
|
||||
|
||||
// Allocate outputs in BMHK format
|
||||
p.gQKV_strideM_multiplier = 1;
|
||||
p.gQ_strideH = p.head_dim;
|
||||
p.gQ_strideB = p.gQ_strideM() * p.num_queries;
|
||||
p.gK_strideH = p.head_dim;
|
||||
p.gK_strideB = p.gK_strideM() * p.num_keys;
|
||||
p.gV_strideH = p.head_dim_value;
|
||||
p.gV_strideB = p.gV_strideM() * p.num_keys;
|
||||
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gQ({int64_t(1), p.gQ_strideB * p.num_batches});
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gK({int64_t(1), p.gK_strideB * p.num_batches});
|
||||
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gV({int64_t(1), p.gV_strideB * p.num_batches});
|
||||
p.grad_query_ptr = gQ.device_data();
|
||||
p.grad_key_ptr = gK.device_data();
|
||||
p.grad_value_ptr = gV.device_data();
|
||||
|
||||
if (!Kernel::check_supported(p)) {
|
||||
std::cerr << "FATAL: Kernel does not support these inputs" << std::endl;
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Run kernel
|
||||
cudaDeviceSynchronize();
|
||||
auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
|
||||
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes)));
|
||||
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
|
||||
|
||||
// Write outputs
|
||||
std::cout << "OK ";
|
||||
writeTensor("grad_query", gQ);
|
||||
writeInt64("gQ_strideB", p.gQ_strideB);
|
||||
writeInt64("gQ_strideM", p.gQ_strideM());
|
||||
writeInt64("gQ_strideH", p.gQ_strideH);
|
||||
writeTensor("grad_key", gK);
|
||||
writeInt64("gK_strideB", p.gK_strideB);
|
||||
writeInt64("gK_strideM", p.gK_strideM());
|
||||
writeInt64("gK_strideH", p.gK_strideH);
|
||||
writeTensor("grad_value", gV);
|
||||
writeInt64("gV_strideB", p.gV_strideB);
|
||||
writeInt64("gV_strideM", p.gV_strideM());
|
||||
writeInt64("gV_strideH", p.gV_strideH);
|
||||
|
||||
// Timing
|
||||
cudaEvent_t events[2];
|
||||
for (auto & event : events) {
|
||||
CUDA_CHECK(cudaEventCreate(&event));
|
||||
}
|
||||
CUDA_CHECK(cudaEventRecord(events[0]));
|
||||
for (int i = 0; i < repeat_count; ++i) {
|
||||
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
|
||||
}
|
||||
CUDA_CHECK(cudaEventRecord(events[1]));
|
||||
CUDA_CHECK(cudaEventSynchronize(events[1]));
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
CUDA_CHECK(cudaEventElapsedTime(&runtime_ms, events[0], events[1]));
|
||||
|
||||
std::cout << "runtime_ms " << runtime_ms / float(repeat_count) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::ios_base::sync_with_stdio(false);
|
||||
|
||||
std::string dtype;
|
||||
std::cin >> dtype;
|
||||
std::cerr << "Running kernel with dtype: " << dtype << std::endl;
|
||||
if (dtype == "f16") {
|
||||
return runKernel<cutlass::half_t>();
|
||||
} else if (dtype == "b16") {
|
||||
return runKernel<cutlass::bfloat16_t>();
|
||||
} else if (dtype == "f32") {
|
||||
return runKernel<float>();
|
||||
} else {
|
||||
std::cerr << "FATAL: Unknown dtype: " << dtype << std::endl;
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -999,7 +999,7 @@ public:
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration
|
||||
int kMaxK
|
||||
>
|
||||
int run_attention(Options& options) {
|
||||
using Attention = AttentionKernel<
|
||||
@ -1008,7 +1008,7 @@ int run_attention(Options& options) {
|
||||
true, // Memory is aligned
|
||||
kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
false, // Supports dropout
|
||||
false // Supports bias
|
||||
>;
|
||||
@ -1094,15 +1094,16 @@ int main(int argc, char const **args) {
|
||||
if (options.head_size_v > 64) {
|
||||
static int const kQueriesPerBlock = 32;
|
||||
static int const kKeysPerBlock = 128;
|
||||
if (options.head_size_v <= kKeysPerBlock) {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
|
||||
if (options.head_size_v <= 128) {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
|
||||
} else {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
|
||||
}
|
||||
} else {
|
||||
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
|
||||
static int const kQueriesPerBlock = 64;
|
||||
static int const kKeysPerBlock = 64;
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1061,7 +1061,7 @@ public:
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration,
|
||||
int kMaxK,
|
||||
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_
|
||||
>
|
||||
int run_grouped(Options& options) {
|
||||
@ -1071,7 +1071,7 @@ int run_grouped(Options& options) {
|
||||
true, // Memory is aligned
|
||||
kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
GroupScheduleMode_
|
||||
>::FMHAKernel;
|
||||
|
||||
@ -1098,18 +1098,18 @@ int run_grouped(Options& options) {
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration
|
||||
int kMaxK
|
||||
>
|
||||
int run_attention(Options& options) {
|
||||
if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) {
|
||||
return run_grouped<kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>(options);
|
||||
} else {
|
||||
return run_grouped<kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>(options);
|
||||
}
|
||||
}
|
||||
@ -1180,14 +1180,15 @@ int main(int argc, char const **args) {
|
||||
static int const kQueriesPerBlock = 32;
|
||||
static int const kKeysPerBlock = 128;
|
||||
if (options.head_size_v <= kKeysPerBlock) {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
|
||||
} else {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
|
||||
}
|
||||
} else {
|
||||
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
|
||||
static int const kQueriesPerBlock = 64;
|
||||
static int const kKeysPerBlock = 64;
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
accum = plus_accum(accum, tmp_accum);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM
|
||||
// mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER*
|
||||
// issuing shared memory loads (which have the tightest latency requirement).
|
||||
// issuing shared memory loads (which have the tightest latency
|
||||
// requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
|
||||
@ -30,7 +30,8 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
\brief Tools and utils to store a GEMM output in shmem, and to use that
|
||||
output as operandA for another GEMM back-to-back
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@ -55,6 +56,7 @@
|
||||
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
|
||||
#include "../gemm/mma_accum_lambda_iterator.h"
|
||||
#include "../gemm_kernel_utils.h"
|
||||
#include "../iterators/default_warp_iterator_from_smem.h"
|
||||
#include "../iterators/make_residual_last.h"
|
||||
#include "../iterators/transpose_warp_iterator.h"
|
||||
#include "../iterators/warp_iterator_from_smem.h"
|
||||
@ -128,18 +130,22 @@ class AccumulatorSharedStorage {
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
// Maximum value for K
|
||||
int kMaxK,
|
||||
// Maximum K dimension - also the dimension of the shared-memory
|
||||
// holding `OperandA`
|
||||
int kMaxK_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Layout in shared-memory of operand A
|
||||
typename SmemLayoutA,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaBaseFromSharedMemory {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
static constexpr int kMaxK = kMaxK_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory {
|
||||
static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA =
|
||||
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB =
|
||||
@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory {
|
||||
CUTLASS_DEVICE
|
||||
MmaBaseFromSharedMemory(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage& shared_storage,
|
||||
TensorRefB& b_tile,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
: warp_tile_iterator_B_(b_tile, lane_idx) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
@ -264,9 +269,8 @@ class NoOpWarpIteratorScale {
|
||||
// in pipelined+multistage MMA implementations we keep an array of fragments.
|
||||
// if we aren't using scaling we don't want to waste registers on fragments
|
||||
// of scale elements, so ideally this would be sized 0.
|
||||
// using size 1 is kind of a hack to get around arrays of zero-sized objects
|
||||
// not being allowed. the compiler is probably smart enough to wipe it out
|
||||
// anyways.
|
||||
// Since arrays of zero-sized objects are not allowed, using size as 1.
|
||||
// The compiler will most likely wipe it out anyways.
|
||||
using Fragment = cutlass::Array<char, 1>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -334,14 +338,13 @@ template <
|
||||
typename Shape_,
|
||||
// BEGIN smem
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA,
|
||||
typename WarpIteratorA_,
|
||||
/// whether or not to perform elementwise multiplication of A
|
||||
// by another matrix (A_scale) that is also kept in shared memory prior
|
||||
// to matmul A @ B
|
||||
bool ScaleOperandA_,
|
||||
// Accumulator type
|
||||
typename AccumulatorSharedStorage,
|
||||
// END smem
|
||||
/// Max GEMM problem size in K dimension
|
||||
int MaxK,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -364,21 +367,24 @@ template <
|
||||
typename Enable = bool>
|
||||
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
Shape_,
|
||||
AccumulatorSharedStorage::Shape::kN,
|
||||
MaxK,
|
||||
Policy_,
|
||||
2> {
|
||||
2,
|
||||
typename WarpIteratorA_::Layout> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = MmaBaseFromSharedMemory<
|
||||
Shape_,
|
||||
AccumulatorSharedStorage::Shape::kN,
|
||||
MaxK,
|
||||
Policy_,
|
||||
2>;
|
||||
2,
|
||||
typename WarpIteratorA_::Layout>;
|
||||
|
||||
using Shape =
|
||||
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
static constexpr bool ScaleOperandA = ScaleOperandA_;
|
||||
|
||||
using WarpIteratorA = WarpIteratorA_;
|
||||
///< loads fragments of A_scale from shared memory if operand A scaling is
|
||||
///< enabled. otherwise no-op.
|
||||
using WarpIteratorAScale = typename cutlass::platform::conditional<
|
||||
@ -455,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
/// constructor for MMA with operand A scaling enabled.
|
||||
CUTLASS_DEVICE
|
||||
MmaPipelinedFromSharedMemory(
|
||||
// shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
// warp iterator over A tile held in shared memory
|
||||
WarpIteratorA warp_iter_a,
|
||||
// warp iterator over A_scale tile held in shared memory
|
||||
WarpIteratorAScale warp_iter_a_scale,
|
||||
typename Base::TensorRefA a, // Operand A in shared memory
|
||||
typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
|
||||
typename Base::TensorRefB
|
||||
b_staging, // staging memory for loading tiles of B
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(warp_iter_a),
|
||||
warp_tile_iterator_A_scale_(warp_iter_a_scale),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
: Base(b_staging, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(a, lane_idx),
|
||||
warp_tile_iterator_A_scale_(a_scale, lane_idx),
|
||||
smem_iterator_B_(b_staging, thread_idx) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
@ -490,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaPipelinedFromSharedMemory(
|
||||
typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by
|
||||
///< threadblock-scoped GEMM
|
||||
AccumulatorSharedStorage& accumulator_shared_storage,
|
||||
typename Base::TensorRefA a, ///< Operand A in shared memory
|
||||
typename Base::TensorRefB b_staging, ///< staging memory for loading B
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx, ///< ID of each thread within a warp
|
||||
int problem_size_0_n)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
int lane_idx) ///< ID of each thread within a warp
|
||||
: Base(b_staging, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(a, lane_idx),
|
||||
smem_iterator_B_(b_staging, thread_idx) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
@ -532,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
int thread_idx,
|
||||
int problem_size_0_n) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void drain_cp_asyncs() {}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
@ -600,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER*
|
||||
// issuing shared memory loads (which have the tightest latency requirement).
|
||||
// issuing shared memory loads (which have the tightest latency
|
||||
// requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -621,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
bool hasNext = true;
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
if (gemm_k_iterations > 1) {
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@ -696,8 +703,6 @@ template <
|
||||
// by another matrix (A_scale) that is also kept in shared memory prior
|
||||
// to matmul A @ B
|
||||
bool ScaleOperandA_,
|
||||
// Accumulator type
|
||||
typename AccumulatorSharedStorage,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -718,11 +723,20 @@ template <
|
||||
int kMaxK_,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaMultistageFromSharedMemory
|
||||
: public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
|
||||
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
Shape1_,
|
||||
kMaxK_,
|
||||
Policy1_,
|
||||
Stages_,
|
||||
typename WarpIteratorA1_::Layout> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
|
||||
using Base = MmaBaseFromSharedMemory<
|
||||
Shape1_,
|
||||
kMaxK_,
|
||||
Policy1_,
|
||||
Stages_,
|
||||
typename WarpIteratorA1_::Layout>;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
@ -826,20 +840,16 @@ class MmaMultistageFromSharedMemory
|
||||
/// constructor for MMA with operand A scaling enabled.
|
||||
CUTLASS_DEVICE
|
||||
MmaMultistageFromSharedMemory(
|
||||
// shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
// warp level iterator over operand A tile kept in shared memory
|
||||
WarpIteratorA1 warp_tile_iterator_A1,
|
||||
// warp level iterator over operand A elementwise scale tile kept in
|
||||
// shared memory.
|
||||
WarpIteratorAScale warp_tile_iterator_A1_scale,
|
||||
typename Base::TensorRefA a,
|
||||
typename Base::TensorRefA a_scale,
|
||||
typename Base::TensorRefB b_tile,
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(warp_tile_iterator_A1),
|
||||
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
|
||||
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
|
||||
: Base(b_tile, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(a, lane_idx),
|
||||
warp_tile_iterator_A1_scale_(a_scale, lane_idx),
|
||||
smem_iterator_B1_(b_tile, thread_idx),
|
||||
prologue_done_(false) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@ -864,23 +874,17 @@ class MmaMultistageFromSharedMemory
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaMultistageFromSharedMemory(
|
||||
typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by
|
||||
///< threadblock-scoped GEMM
|
||||
AccumulatorSharedStorage& accumulator_shared_storage,
|
||||
typename Base::TensorRefA a,
|
||||
typename Base::TensorRefB b_tile,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx,
|
||||
///< GEMM0 N is used for accumulator extent
|
||||
int problem_size_0_n)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(
|
||||
accumulator_shared_storage.accum_ref(),
|
||||
lane_idx),
|
||||
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
|
||||
int lane_idx)
|
||||
: Base(b_tile, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(a, lane_idx),
|
||||
smem_iterator_B1_(b_tile, thread_idx),
|
||||
prologue_done_(false) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@ -920,6 +924,15 @@ class MmaMultistageFromSharedMemory
|
||||
smem_iterator_B1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void drain_cp_asyncs() {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM
|
||||
// mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(
|
||||
IteratorB1& iterator_B1,
|
||||
@ -1254,100 +1267,11 @@ class MmaMultistageFromSharedMemory
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename RegularWarpIterator,
|
||||
typename Policy,
|
||||
typename Enable = void>
|
||||
struct DefaultWarpIteratorAFromSharedMemory {};
|
||||
|
||||
// TensorOp - Ampere half
|
||||
template <typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
cutlass::gemm::GemmShape<32, 32, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
|
||||
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
using WarpShape = cutlass::MatrixShape<32, 32>;
|
||||
|
||||
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element>;
|
||||
};
|
||||
|
||||
// TensorOp - Ampere f32
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
|
||||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// TensorOp - Volta
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 16, 4>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
|
||||
// WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
|
||||
cutlass::MatrixShape<16, 4>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// Simt
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<1, 1, 1>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
|
||||
// We just use the same iterator, as we reproduced the same shared-memory
|
||||
// schema. Just modify it to handle non-complete tiles.
|
||||
using WarpIterator = RegularWarpIterator;
|
||||
};
|
||||
|
||||
// Converts a "regular" Mma into their counterpart from shared memory
|
||||
template <
|
||||
typename Mma_,
|
||||
typename AccumulatorSharedStorage,
|
||||
int kMaxK,
|
||||
typename WarpIteratorA_,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
@ -1365,6 +1289,7 @@ template <
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
typename WarpIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -1382,7 +1307,8 @@ template <
|
||||
typename TransformA_,
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB_,
|
||||
typename AccumulatorSharedStorage_,
|
||||
// Max MMA problem size K
|
||||
int kMaxK,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
@ -1399,12 +1325,10 @@ struct DefaultMmaFromSharedMemory<
|
||||
Policy_,
|
||||
TransformA_,
|
||||
TransformB_>,
|
||||
AccumulatorSharedStorage_,
|
||||
kMaxK,
|
||||
WarpIteratorA_,
|
||||
kScaleOperandA,
|
||||
kTransposeA> {
|
||||
static constexpr int kWarpSize = 32;
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
|
||||
using RegularMma = MmaPipelined<
|
||||
Shape_,
|
||||
IteratorA_,
|
||||
@ -1422,11 +1346,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
using ArchMmaOperator = typename Policy_::Operator;
|
||||
|
||||
static constexpr bool kIsTransposedA = false;
|
||||
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
typename RegularMma::Operator::IteratorA,
|
||||
Policy_>::WarpIterator;
|
||||
using WarpIteratorA = WarpIteratorA_;
|
||||
using IteratorB =
|
||||
typename cutlass::transform::threadblock::MakeIteratorResidualLast<
|
||||
IteratorB_>::Iterator;
|
||||
@ -1435,7 +1355,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
Shape_,
|
||||
WarpIteratorA,
|
||||
kScaleOperandA,
|
||||
AccumulatorSharedStorage_,
|
||||
kMaxK,
|
||||
IteratorB,
|
||||
SmemIteratorB_,
|
||||
ElementC_,
|
||||
@ -1453,6 +1373,7 @@ template <
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
typename WarpIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
@ -1474,7 +1395,7 @@ template <
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
typename AccumulatorSharedStorage_,
|
||||
int kMaxK,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
@ -1493,11 +1414,10 @@ struct DefaultMmaFromSharedMemory<
|
||||
Policy_,
|
||||
Stages,
|
||||
SharedMemoryClear>,
|
||||
AccumulatorSharedStorage_,
|
||||
kMaxK,
|
||||
WarpIteratorA_,
|
||||
kScaleOperandA,
|
||||
kTransposeA> {
|
||||
static constexpr int kWarpSize = 32;
|
||||
|
||||
using RegularMma = MmaMultistage<
|
||||
Shape_,
|
||||
IteratorA_,
|
||||
@ -1514,11 +1434,6 @@ struct DefaultMmaFromSharedMemory<
|
||||
|
||||
using WarpShape = typename Policy_::Operator::Shape;
|
||||
using InstructionShape = typename Policy_::Operator::InstructionShape;
|
||||
using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
typename RegularMma::Operator::IteratorA,
|
||||
Policy_>::WarpIterator;
|
||||
using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
|
||||
static constexpr bool kIsTransposedA =
|
||||
WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
|
||||
@ -1527,9 +1442,6 @@ struct DefaultMmaFromSharedMemory<
|
||||
typename WarpIteratorTranspose::Iterator,
|
||||
WarpIteratorA_>::type;
|
||||
|
||||
static int constexpr kMaxK = kIsTransposedA
|
||||
? AccumulatorSharedStorage_::Shape::kM
|
||||
: AccumulatorSharedStorage_::Shape::kN;
|
||||
// Reduce the number of stages if we don't need that many
|
||||
static int constexpr kStagesMax =
|
||||
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
|
||||
@ -1543,7 +1455,6 @@ struct DefaultMmaFromSharedMemory<
|
||||
Shape_,
|
||||
WarpIteratorA,
|
||||
kScaleOperandA,
|
||||
AccumulatorSharedStorage_,
|
||||
IteratorB,
|
||||
SmemIteratorB_,
|
||||
RegularMma::kCacheOpB,
|
||||
@ -1751,27 +1662,17 @@ struct B2bGemm<
|
||||
using FragmentC = IteratorC::Fragment;
|
||||
using lse_scalar_t = float;
|
||||
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<32, 32, 4>,
|
||||
scalar_t,
|
||||
SmemAccumulatorLayout>;
|
||||
|
||||
// // Storage in shared-memory for Q.Kt
|
||||
// Storage in shared-memory for Q.Kt
|
||||
using SmemAccumulatorLayout =
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
|
||||
using AccumulatorSharedStorage =
|
||||
cutlass::gemm::threadblock::AccumulatorSharedStorage<
|
||||
ThreadblockShape,
|
||||
scalar_t,
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
|
||||
16,
|
||||
32>, // typename SmemIteratorD0::TensorLayout,
|
||||
SmemAccumulatorLayout,
|
||||
cutlass::MatrixShape<0, 0> // Padding
|
||||
>;
|
||||
|
||||
using OutputLayout =
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
|
||||
using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
|
||||
using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
|
||||
using Policy = typename IteratorC::Policy;
|
||||
using Element = accum_t;
|
||||
// Those are MmaVoltaTensorOpAccumulatorTileIterator private fields
|
||||
|
||||
@ -115,10 +115,10 @@
|
||||
std::cerr << #PTR " is not correctly aligned\n"; \
|
||||
return false; \
|
||||
}
|
||||
#define XFORMERS_CHECK(COND, ERR) \
|
||||
if (!(COND)) { \
|
||||
std::cerr << #COND " failed\n"; \
|
||||
return false; \
|
||||
#define XFORMERS_CHECK(COND, ERR) \
|
||||
if (!(COND)) { \
|
||||
std::cerr << "'" #COND "' failed: " << ERR << "\n"; \
|
||||
return false; \
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -228,8 +228,17 @@ struct call_conditional<false, TA, TB> {
|
||||
// The cheapest way to do it is just to broadcast it from lane 0
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
|
||||
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
|
||||
template <typename T>
|
||||
CUTLASS_DEVICE T warp_uniform(T value) {
|
||||
struct {
|
||||
union {
|
||||
T value;
|
||||
uint32_t asInt;
|
||||
};
|
||||
} p;
|
||||
p.value = value;
|
||||
p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
|
||||
return p.value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -0,0 +1,143 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Instanciates the right WarpIterator to read from shared memory
|
||||
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
|
||||
data dumped with `B2bGemm::accumToSmem`.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "warp_iterator_from_smem.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename RegularWarpIterator,
|
||||
typename Policy,
|
||||
typename Enable = void>
|
||||
struct DefaultWarpIteratorAFromSharedMemory {};
|
||||
|
||||
// TensorOp - Ampere half
|
||||
template <typename RegularWarpIterator, typename Policy, int kInstrK>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
cutlass::gemm::GemmShape<32, 32, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, kInstrK>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
|
||||
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
using WarpShape = cutlass::MatrixShape<32, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
|
||||
|
||||
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
|
||||
};
|
||||
|
||||
// TensorOp - Ampere f32
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
|
||||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// TensorOp - Volta
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 16, 4>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
|
||||
// WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
|
||||
cutlass::MatrixShape<16, 4>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// Simt
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<1, 1, 1>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
|
||||
// We just use the same iterator, as we reproduced the same shared-memory
|
||||
// schema. Just modify it to handle non-complete tiles.
|
||||
using WarpIterator = RegularWarpIterator;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -175,7 +175,7 @@ class PredicatedTileAccessIteratorResidualLast<
|
||||
Mask residual_tile_mask;
|
||||
|
||||
/// Parameters object with precomputed internal state
|
||||
Params const& params_;
|
||||
Params params_;
|
||||
|
||||
/// Internal pointer to first access of tile
|
||||
BytePointer pointer_;
|
||||
@ -1018,7 +1018,7 @@ class PredicatedTileAccessIteratorResidualLast<
|
||||
//
|
||||
|
||||
/// Parameters object with precomputed internal state
|
||||
Params const& params_;
|
||||
Params params_;
|
||||
|
||||
/// Internal pointer to first access of tile
|
||||
BytePointer pointer_;
|
||||
|
||||
@ -44,10 +44,12 @@ template <
|
||||
cutlass::gemm::Operand Operand,
|
||||
/// Data type of A elements
|
||||
typename Element,
|
||||
typename InstructionShape,
|
||||
bool kTranspose>
|
||||
struct TransposeWarpIterator<
|
||||
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
|
||||
using Iterator =
|
||||
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
|
||||
cutlass::gemm::warp::
|
||||
WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
|
||||
using Iterator = cutlass::gemm::warp::
|
||||
WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
|
||||
static bool constexpr kSupportsTranspose = true;
|
||||
};
|
||||
|
||||
@ -56,6 +56,7 @@ template <
|
||||
Operand Operand_,
|
||||
/// Data type of A elements
|
||||
typename Element_,
|
||||
typename InstructionShape_,
|
||||
bool kTranspose = false>
|
||||
class WarpIteratorFromSmem {
|
||||
public:
|
||||
@ -64,6 +65,9 @@ class WarpIteratorFromSmem {
|
||||
|
||||
/// Operand tag
|
||||
static Operand const kOperand = Operand_;
|
||||
static_assert(
|
||||
kOperand == Operand::kA,
|
||||
"No support for OperandB at the moment");
|
||||
|
||||
/// Basic check
|
||||
static_assert(
|
||||
@ -78,7 +82,11 @@ class WarpIteratorFromSmem {
|
||||
using Layout = cutlass::layout::RowMajor;
|
||||
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
using InstructionShape = cutlass::MatrixShape<16, 8>;
|
||||
using InstructionShape = InstructionShape_;
|
||||
static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
|
||||
static_assert(
|
||||
InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
|
||||
"Only supports 16x8x8 / 16x8x16");
|
||||
|
||||
/// Delta between *MMA operations (in units of *MMA operations, concept:
|
||||
/// MatrixShape)
|
||||
@ -133,7 +141,9 @@ class WarpIteratorFromSmem {
|
||||
: InstructionShape::kRow);
|
||||
static int constexpr kAccessesInner =
|
||||
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
|
||||
// Number of 32bits tiles to load per `ldmatrix`
|
||||
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
|
||||
static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
|
||||
|
||||
private:
|
||||
/// Underlying tensor reference
|
||||
@ -153,38 +163,28 @@ class WarpIteratorFromSmem {
|
||||
CUTLASS_HOST_DEVICE
|
||||
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
|
||||
: ref_(ref), iterations_(0) {
|
||||
// See also:
|
||||
// https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
|
||||
// 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
|
||||
// 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
|
||||
int ldsm_vec_num = (lane_id >> 3);
|
||||
if (kOperand == Operand::kA) {
|
||||
origin_ = MatrixCoord(lane_id % 8, 0);
|
||||
static_assert(
|
||||
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
|
||||
"");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
|
||||
++inst_m_idx) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
|
||||
++access_m_idx) {
|
||||
int access_idx = access_m_idx +
|
||||
kTilesPerInstruction *
|
||||
(inner_idx + kAccessesInner * inst_m_idx);
|
||||
|
||||
MatrixCoord offset(
|
||||
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
|
||||
inner_idx * 4 * kElementsPerAccess);
|
||||
|
||||
if (access_idx == ldsm_vec_num) {
|
||||
if (kTranspose) {
|
||||
offset = MatrixCoord(offset.column(), offset.row());
|
||||
}
|
||||
origin_ += offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
InstructionCount::kRow * kTilesPerInstruction == 4,
|
||||
"can't use ldmatrix.x4");
|
||||
int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
|
||||
int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
|
||||
int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
|
||||
MatrixCoord offset(
|
||||
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
|
||||
inner_idx * 4 * kElementsPerAccess);
|
||||
if (kTranspose) {
|
||||
offset = MatrixCoord(offset.column(), offset.row());
|
||||
}
|
||||
origin_ += offset;
|
||||
} else {
|
||||
// Note: This is not tested or used
|
||||
origin_ = MatrixCoord(0, lane_id % 8);
|
||||
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -256,17 +256,23 @@ class WarpIteratorFromSmem {
|
||||
using LoadLayout = typename platform::
|
||||
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
|
||||
|
||||
MatrixCoord offset;
|
||||
if (kOperand == Operand::kA) {
|
||||
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
|
||||
} else {
|
||||
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int access_m_idx = 0; access_m_idx <
|
||||
(InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
|
||||
++access_m_idx) {
|
||||
MatrixCoord offset;
|
||||
if (kOperand == Operand::kA) {
|
||||
offset = MatrixCoord(
|
||||
access_m_idx * 16, iterations_ * InstructionShape::kColumn);
|
||||
} else {
|
||||
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
|
||||
}
|
||||
if (kTranspose) {
|
||||
offset = MatrixCoord(offset.column(), offset.row());
|
||||
}
|
||||
cutlass::arch::ldsm<LoadLayout, 4>(
|
||||
access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
|
||||
}
|
||||
if (kTranspose) {
|
||||
offset = MatrixCoord(offset.column(), offset.row());
|
||||
}
|
||||
cutlass::arch::ldsm<LoadLayout, 4>(
|
||||
access_ptr[0], ref_.data() + ref_.offset(offset));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
2553
examples/41_fused_multi_head_attention/kernel_backward.h
Normal file
2553
examples/41_fused_multi_head_attention/kernel_backward.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -66,6 +66,7 @@
|
||||
#include "debug_utils.h"
|
||||
#include "epilogue/epilogue_pipelined.h"
|
||||
#include "epilogue/epilogue_rescale_output.h"
|
||||
#include "gemm/custom_mma.h"
|
||||
#include "gemm/find_default_mma.h"
|
||||
#include "gemm/mma_from_smem.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
@ -77,7 +78,7 @@ using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t, typename Arch>
|
||||
constexpr int getWarpsPerSm() {
|
||||
constexpr int getWarpsPerSmFw() {
|
||||
return (
|
||||
Arch::kMinComputeCapability >= 80 &&
|
||||
!cutlass::platform::is_same<scalar_t, float>::value
|
||||
@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// If ToBatchHookType_ is supplied other than this default (which is
|
||||
// never the case in the xformers library) then the user is
|
||||
// defining the logic which each block uses to find its data to work on,
|
||||
// with the advance_to_batch function with the following signature.
|
||||
// It should return false if there is no work to do for this block.
|
||||
// In general this will not work with saving for backward due to fixed layout
|
||||
// for logsumexp and incompatible rngs for dropout, so is likely only useful for
|
||||
// custom inference.
|
||||
struct DefaultToBatchHook {
|
||||
template <typename Params>
|
||||
CUTLASS_DEVICE static bool advance_to_batch(
|
||||
Params&,
|
||||
int64_t& /* q_start */,
|
||||
int64_t& /* k_start */) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
// The datatype of Q/K/V
|
||||
typename scalar_t_,
|
||||
@ -99,13 +118,15 @@ template <
|
||||
typename ArchTag,
|
||||
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
|
||||
bool isAligned_,
|
||||
int kQueriesPerBlock,
|
||||
int kQueriesPerBlock_,
|
||||
int kKeysPerBlock_,
|
||||
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
|
||||
// upperbound on `max(value.shape[-1], query.shape[-1])`
|
||||
int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
|
||||
// This is quite slower on V100 for some reason
|
||||
// Set to false if you know at compile-time you will never need dropout
|
||||
bool kSupportsDropout_ = true,
|
||||
bool kSupportsBias_ = true>
|
||||
bool kSupportsBias_ = true,
|
||||
typename ToBatchHookType_ = DefaultToBatchHook>
|
||||
struct AttentionKernel {
|
||||
enum CustomMaskType {
|
||||
NoCustomMask = 0,
|
||||
@ -125,11 +146,14 @@ struct AttentionKernel {
|
||||
static constexpr bool kSupportsDropout = kSupportsDropout_;
|
||||
static constexpr bool kSupportsBias = kSupportsBias_;
|
||||
static constexpr int kKeysPerBlock = kKeysPerBlock_;
|
||||
static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
|
||||
static constexpr int kMaxK = kMaxK_;
|
||||
static constexpr bool kIsAligned = isAligned_;
|
||||
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
|
||||
static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
|
||||
static constexpr int32_t kAlignLSE = 32; // block size of backward
|
||||
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
|
||||
cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
static constexpr bool kPreloadV =
|
||||
ArchTag::kMinComputeCapability >= 80 && kIsHalf;
|
||||
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
|
||||
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
|
||||
!cutlass::platform::is_same<output_accum_t, output_t>::value;
|
||||
@ -143,66 +167,67 @@ struct AttentionKernel {
|
||||
// Launch bounds
|
||||
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
|
||||
static constexpr int kMinBlocksPerSm =
|
||||
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
|
||||
getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
|
||||
|
||||
struct Params {
|
||||
// Input tensors
|
||||
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
|
||||
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
|
||||
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
|
||||
scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
|
||||
scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
|
||||
scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
|
||||
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
|
||||
int32_t* seqstart_q_ptr = nullptr;
|
||||
int32_t* seqstart_k_ptr = nullptr;
|
||||
|
||||
int32_t* causal_diagonal_ptr = nullptr;
|
||||
int32_t* seqlen_k_ptr = nullptr;
|
||||
uint32_t causal_diagonal_offset = 0;
|
||||
|
||||
// Output tensors
|
||||
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
output_accum_t*
|
||||
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
|
||||
output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
|
||||
// [num_queries, num_heads, head_dim_value]
|
||||
output_accum_t* output_accum_ptr = nullptr;
|
||||
// [num_heads, num_queries] - can be null
|
||||
lse_scalar_t* logsumexp_ptr = nullptr;
|
||||
|
||||
// Scale
|
||||
accum_t scale;
|
||||
accum_t scale = 0.0;
|
||||
|
||||
// Dimensions/strides
|
||||
int32_t head_dim;
|
||||
int32_t head_dim_value;
|
||||
int32_t num_queries;
|
||||
int32_t num_keys;
|
||||
int32_t head_dim = 0;
|
||||
int32_t head_dim_value = 0;
|
||||
int32_t num_queries = 0;
|
||||
int32_t num_keys = 0;
|
||||
int32_t num_keys_absolute = 0;
|
||||
|
||||
uint8_t custom_mask_type = NoCustomMask;
|
||||
|
||||
int32_t q_strideM;
|
||||
int32_t k_strideM;
|
||||
int32_t v_strideM;
|
||||
int32_t q_strideM = 0;
|
||||
int32_t k_strideM = 0;
|
||||
int32_t v_strideM = 0;
|
||||
int32_t bias_strideM = 0;
|
||||
|
||||
int32_t o_strideM = 0;
|
||||
|
||||
// Everything below is only used in `advance_to_block`
|
||||
// and shouldn't use registers
|
||||
int32_t q_strideH;
|
||||
int32_t k_strideH;
|
||||
int32_t v_strideH;
|
||||
int32_t bias_strideH = 0;
|
||||
int32_t q_strideH = 0;
|
||||
int32_t k_strideH = 0;
|
||||
int32_t v_strideH = 0;
|
||||
int64_t bias_strideH = 0;
|
||||
|
||||
int64_t q_strideB;
|
||||
int64_t k_strideB;
|
||||
int64_t v_strideB;
|
||||
int32_t bias_strideB = 0;
|
||||
int64_t q_strideB = 0;
|
||||
int64_t k_strideB = 0;
|
||||
int64_t v_strideB = 0;
|
||||
int64_t bias_strideB = 0;
|
||||
|
||||
int32_t num_batches;
|
||||
int32_t num_heads;
|
||||
int32_t num_batches = 0;
|
||||
int32_t num_heads = 0;
|
||||
|
||||
// dropout
|
||||
bool use_dropout;
|
||||
unsigned long long dropout_batch_head_rng_offset;
|
||||
float dropout_prob;
|
||||
bool use_dropout = false;
|
||||
unsigned long long dropout_batch_head_rng_offset = 0;
|
||||
float dropout_prob = 0.0f;
|
||||
#ifdef HAS_PYTORCH
|
||||
at::PhiloxCudaState rng_engine_inputs;
|
||||
at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
|
||||
#endif
|
||||
|
||||
// Moves pointers to what we should process
|
||||
@ -220,9 +245,17 @@ struct AttentionKernel {
|
||||
head_id * num_queries * num_keys;
|
||||
}
|
||||
|
||||
int64_t q_start, k_start;
|
||||
int64_t q_start = 0, k_start = 0;
|
||||
// Advance to current batch - in case of different sequence lengths
|
||||
if (seqstart_q_ptr != nullptr) {
|
||||
constexpr bool kToBatchHook =
|
||||
!cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
|
||||
value;
|
||||
if (kToBatchHook) {
|
||||
// Call out to a custom implementation.
|
||||
if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
|
||||
return false;
|
||||
}
|
||||
} else if (seqstart_q_ptr != nullptr) {
|
||||
assert(seqstart_k_ptr != nullptr);
|
||||
seqstart_q_ptr += batch_id;
|
||||
|
||||
@ -285,12 +318,12 @@ struct AttentionKernel {
|
||||
}
|
||||
|
||||
// Custom masking
|
||||
if (causal_diagonal_ptr) {
|
||||
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
|
||||
}
|
||||
if (custom_mask_type == CausalFromBottomRight) {
|
||||
causal_diagonal_offset += num_keys - num_queries;
|
||||
causal_diagonal_offset = num_keys - num_queries;
|
||||
}
|
||||
// We use num_keys_absolute to index into the rng_state
|
||||
// We need this index to match between forward and backwards
|
||||
num_keys_absolute = num_keys;
|
||||
if (custom_mask_type == CausalFromTopLeft ||
|
||||
custom_mask_type == CausalFromBottomRight) {
|
||||
// the bottom row of the current block is query_start + kQueriesPerBlock
|
||||
@ -323,6 +356,7 @@ struct AttentionKernel {
|
||||
|
||||
// Make sure the compiler knows these variables are the same on all
|
||||
// the threads of the warp.
|
||||
// Only worth doing if they could have been modified above.
|
||||
query_ptr = warp_uniform(query_ptr);
|
||||
key_ptr = warp_uniform(key_ptr);
|
||||
value_ptr = warp_uniform(value_ptr);
|
||||
@ -335,8 +369,6 @@ struct AttentionKernel {
|
||||
num_queries = warp_uniform(num_queries);
|
||||
num_keys = warp_uniform(num_keys);
|
||||
num_heads = warp_uniform(num_heads);
|
||||
head_dim = warp_uniform(head_dim);
|
||||
head_dim_value = warp_uniform(head_dim_value);
|
||||
o_strideM = warp_uniform(o_strideM);
|
||||
custom_mask_type = warp_uniform(custom_mask_type);
|
||||
return true;
|
||||
@ -395,14 +427,19 @@ struct AttentionKernel {
|
||||
ThreadblockShape, // ThreadblockShape
|
||||
WarpShape, // WarpShape
|
||||
typename GemmType::InstructionShape, // InstructionShape
|
||||
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
|
||||
// uses too much smem
|
||||
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
||||
? 4
|
||||
: DefaultConfig::kStages,
|
||||
typename GemmType::Operator // Operator
|
||||
>::DefaultMma;
|
||||
using MmaCore = typename DefaultMma::MmaCore;
|
||||
using IteratorA = typename DefaultMma::IteratorA;
|
||||
using IteratorB = typename DefaultMma::IteratorB;
|
||||
using Mma = typename DefaultMma::ThreadblockMma;
|
||||
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
|
||||
using Mma = typename cutlass::platform::conditional<
|
||||
kSingleValueIteration,
|
||||
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
|
||||
DefaultThreadblockMma>::type;
|
||||
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
|
||||
typename Mma::Operator::IteratorC,
|
||||
accum_t,
|
||||
@ -475,14 +512,23 @@ struct AttentionKernel {
|
||||
typename GemmType::InstructionShape,
|
||||
typename DefaultConfig::EpilogueOutputOp,
|
||||
void, // ThreadblockSwizzle - not used
|
||||
DefaultConfig::kStages,
|
||||
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
||||
? 4
|
||||
: DefaultConfig::kStages,
|
||||
false, // SplitKSerial
|
||||
typename GemmType::Operator>;
|
||||
|
||||
using WarpIteratorA = typename cutlass::gemm::threadblock::
|
||||
DefaultWarpIteratorAFromSharedMemory<
|
||||
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
|
||||
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
|
||||
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
|
||||
typename DefaultGemm::Mma::Policy>::WarpIterator;
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage,
|
||||
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
|
||||
WarpIteratorA,
|
||||
false>; // kScaleOperandA
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
@ -500,10 +546,6 @@ struct AttentionKernel {
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_accum_t>;
|
||||
|
||||
struct SharedStorageMM1 {
|
||||
typename Mma::SharedStorage mm;
|
||||
};
|
||||
};
|
||||
|
||||
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
|
||||
@ -515,6 +557,9 @@ struct AttentionKernel {
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> mi;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
|
||||
addition_storage;
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
@ -524,7 +569,7 @@ struct AttentionKernel {
|
||||
typename MM0::BiasLoader::SmemTile bias;
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
};
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::Mma::SharedStorage mm1;
|
||||
};
|
||||
|
||||
union {
|
||||
@ -546,7 +591,7 @@ struct AttentionKernel {
|
||||
typename MM0::BiasLoader::SmemTile bias;
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
};
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::Mma::SharedStorage mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
@ -573,30 +618,33 @@ struct AttentionKernel {
|
||||
if (kSupportsBias) {
|
||||
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
|
||||
XFORMERS_CHECK(
|
||||
p.bias_strideB % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned");
|
||||
p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned (strideB)");
|
||||
XFORMERS_CHECK(
|
||||
p.bias_strideH % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned");
|
||||
p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned (strideH)");
|
||||
XFORMERS_CHECK(
|
||||
p.bias_strideM % kAlignmentQ == 0,
|
||||
"attn_bias is not correctly aligned");
|
||||
}
|
||||
XFORMERS_CHECK(
|
||||
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
|
||||
p.q_strideM % kAlignmentQ == 0,
|
||||
"query is not correctly aligned (strideM)");
|
||||
XFORMERS_CHECK(
|
||||
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
|
||||
p.k_strideM % kAlignmentK == 0,
|
||||
"key is not correctly aligned (strideM)");
|
||||
XFORMERS_CHECK(
|
||||
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
|
||||
p.v_strideM % kAlignmentV == 0,
|
||||
"value is not correctly aligned (strideM)");
|
||||
XFORMERS_CHECK(
|
||||
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
|
||||
p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
|
||||
"query is not correctly aligned (strideH)");
|
||||
XFORMERS_CHECK(
|
||||
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
|
||||
p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
|
||||
"key is not correctly aligned (strideH)");
|
||||
XFORMERS_CHECK(
|
||||
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
|
||||
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
|
||||
p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
|
||||
"value is not correctly aligned (strideH)");
|
||||
XFORMERS_CHECK(
|
||||
p.custom_mask_type < NumCustomMaskTypes,
|
||||
"invalid value for `custom_mask_type`");
|
||||
@ -613,11 +661,13 @@ struct AttentionKernel {
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
auto& mi = shared_storage.mi;
|
||||
auto& out_rescale = shared_storage.out_rescale;
|
||||
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
|
||||
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = accum_t(0);
|
||||
out_rescale[thread_id()] = accum_t(1.0);
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
@ -689,7 +739,7 @@ struct AttentionKernel {
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
MM1::Mma::prologue(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.mm1,
|
||||
iterator_V,
|
||||
thread_id(),
|
||||
problem_size_1_k);
|
||||
@ -733,7 +783,7 @@ struct AttentionKernel {
|
||||
thread_id(),
|
||||
tb_offset_B);
|
||||
|
||||
auto my_warp_id = warp_id();
|
||||
auto my_warp_id = warp_uniform(warp_id());
|
||||
auto my_lane_id = lane_id();
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
@ -753,6 +803,8 @@ struct AttentionKernel {
|
||||
|
||||
if (kPreloadV) {
|
||||
prologueV(0);
|
||||
} else {
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
}
|
||||
|
||||
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
||||
@ -787,7 +839,7 @@ struct AttentionKernel {
|
||||
|
||||
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
|
||||
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
my_lane_id, my_warp_id, iteratorC_tile_offset);
|
||||
MM0::AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {},
|
||||
@ -811,7 +863,7 @@ struct AttentionKernel {
|
||||
(query_start + p.causal_diagonal_offset)) {
|
||||
auto query_start = blockIdx.x * kQueriesPerBlock;
|
||||
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
my_lane_id, my_warp_id, iteratorC_tile_offset);
|
||||
int32_t last_col;
|
||||
MM0::AccumLambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
@ -830,30 +882,23 @@ struct AttentionKernel {
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
p.num_keys - iter_key_start >= kKeysPerBlock,
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<
|
||||
typename MM0::Mma::Operator::IteratorC,
|
||||
kFullColumns,
|
||||
kIsFirst>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
p.num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
kSupportsBias ? 1.0f : p.scale);
|
||||
}));
|
||||
}));
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
out_rescale,
|
||||
shared_storage.addition_storage,
|
||||
my_lane_id,
|
||||
thread_id(),
|
||||
my_warp_id,
|
||||
p.num_keys - iter_key_start,
|
||||
iter_key_start == 0,
|
||||
iteratorC_tile_offset,
|
||||
kSupportsBias ? 1.0f : p.scale);
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = my_warp_id %
|
||||
@ -904,7 +949,7 @@ struct AttentionKernel {
|
||||
curandStatePhilox4_32_10_t curand_state = curand_state_init;
|
||||
skipahead(
|
||||
static_cast<unsigned long long>(
|
||||
(query_start + thread_i) * p.num_keys +
|
||||
(query_start + thread_i) * p.num_keys_absolute +
|
||||
(iter_key_start + thread_start_j)),
|
||||
&curand_state);
|
||||
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
|
||||
@ -958,12 +1003,14 @@ struct AttentionKernel {
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
typename MM1::Mma mma_pv(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.si,
|
||||
// operand A: Pij_dropped in shared memory
|
||||
shared_storage.after_mm0.si.accum_ref(),
|
||||
// operand B: shared memory staging area for Vj, which is loaded
|
||||
// from global memory
|
||||
shared_storage.after_mm0.mm1.operand_B_ref(),
|
||||
(int)thread_id(),
|
||||
(int)warp_id(),
|
||||
(int)lane_id(),
|
||||
(int)problem_size_1_k);
|
||||
(int)my_warp_id,
|
||||
(int)my_lane_id);
|
||||
mma_pv.set_prologue_done(kPreloadV);
|
||||
if (!kKeepOutputInRF) {
|
||||
accum_o.clear();
|
||||
@ -976,6 +1023,7 @@ struct AttentionKernel {
|
||||
}
|
||||
|
||||
if (!kKeepOutputInRF) {
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
DISPATCH_BOOL(
|
||||
iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
@ -1027,12 +1075,12 @@ struct AttentionKernel {
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
EpilogueOutputOp rescale(s_prime, out_rescale);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
my_warp_id,
|
||||
my_lane_id);
|
||||
epilogue(rescale, dest_iter, accum_o, source_iter);
|
||||
}));
|
||||
}));
|
||||
@ -1076,12 +1124,13 @@ struct AttentionKernel {
|
||||
typename MM1::OutputTileIteratorAccum // source tile
|
||||
>;
|
||||
auto dest_iter = createOutputIter(0);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
EpilogueOutputOp rescale(s_prime, out_rescale);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
epilogue(rescale, dest_iter, accum_o);
|
||||
}
|
||||
|
||||
@ -1091,8 +1140,9 @@ struct AttentionKernel {
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
|
||||
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (thread_id() < p.num_queries) {
|
||||
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
|
||||
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
|
||||
cutlass::fast_log(accum_t(s_prime[thread_id()]));
|
||||
} else if (thread_id() < lse_dim) {
|
||||
p.logsumexp_ptr[thread_id()] =
|
||||
@ -1101,20 +1151,21 @@ struct AttentionKernel {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename WarpIteratorC,
|
||||
bool kFullColumns,
|
||||
bool kIsFirst>
|
||||
template <typename WarpIteratorC>
|
||||
CUTLASS_DEVICE static void iterative_softmax(
|
||||
typename WarpIteratorC::Fragment& frag_o, // output so far
|
||||
typename WarpIteratorC::Fragment& frag,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
|
||||
addition_storage,
|
||||
int8_t lane_id,
|
||||
int8_t thread_id,
|
||||
int8_t warp_id,
|
||||
int16_t max_col,
|
||||
int max_col,
|
||||
bool is_first,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
float scaling) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
@ -1135,12 +1186,11 @@ struct AttentionKernel {
|
||||
kWarpSize>::Iterator;
|
||||
// Convert to `accum_t` (rather than double)
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (!kIsFirst) {
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
m_prime[thread_id] = mi[thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
|
||||
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
|
||||
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
auto lane_offset =
|
||||
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
|
||||
@ -1154,46 +1204,64 @@ struct AttentionKernel {
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
if (accum_n < max_col) {
|
||||
max = cutlass::fast_max(max, frag[idx]);
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {
|
||||
// Having 4x atomicMax seems faster than reduce within warp
|
||||
// first...
|
||||
atomicMaxFloat(&mi[accum_m], max * scaling);
|
||||
atomicMaxFloat(&mi[accum_m], max);
|
||||
});
|
||||
}
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
// Make sure we all share the update values for `mi`
|
||||
__syncthreads();
|
||||
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
|
||||
m_prime[thread_id] = m_prime_exp;
|
||||
s_prime[thread_id] *= m_prime_exp;
|
||||
// Doing this `exp` is quite expensive. Let's
|
||||
// split it across the warps
|
||||
bool restore_mi_to_minus_inf = false;
|
||||
if (lane_id < kLinesPerWarp) {
|
||||
int id = warp_id * kLinesPerWarp + lane_id;
|
||||
auto m_prime_id = m_prime[id];
|
||||
auto mi_id = mi[id];
|
||||
bool changed = m_prime_id < mi_id; // `false` if both are -inf
|
||||
if (changed) {
|
||||
auto m_prime_exp = exp2f(m_prime_id - mi_id);
|
||||
out_rescale[id] = m_prime_exp;
|
||||
s_prime[id] *= m_prime_exp;
|
||||
} else {
|
||||
// Only when bias is enabled, it's possible that all the first values
|
||||
// of attention are masked to `-inf`. In that case we want to avoid
|
||||
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
|
||||
if (kSupportsBias &&
|
||||
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
|
||||
restore_mi_to_minus_inf = true;
|
||||
mi[id] = 0.0f;
|
||||
}
|
||||
out_rescale[id] = 1.0f;
|
||||
}
|
||||
}
|
||||
__syncthreads(); // Update output fragments
|
||||
if (kKeepOutputInRF && !kIsFirst) {
|
||||
accum_t mp;
|
||||
if (kKeepOutputInRF && !is_first) {
|
||||
accum_t line_rescale;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mp = m_prime[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
|
||||
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag_o[idx] = frag_o[idx] * line_rescale;
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
// Update accum_m, accum_n, ...
|
||||
{
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
|
||||
[&](int accum_m) { mi_row = mi[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] = (kFullColumns || accum_n < max_col)
|
||||
? exp2f(frag[idx] - mi_row)
|
||||
: accum_t(0.0);
|
||||
frag[idx] =
|
||||
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
LambdaIterator::iterateRows(
|
||||
@ -1205,10 +1273,30 @@ struct AttentionKernel {
|
||||
lane_id, total_row, [](accum_t a, accum_t b) {
|
||||
return a + b;
|
||||
})) {
|
||||
atomicAdd(&s_prime[accum_m], total_row);
|
||||
// NOTE: we could atomically add `total_row` to `s_prime`, but
|
||||
// it's faster (and deterministic) to avoid atomics here
|
||||
addition_storage
|
||||
[accum_m + kQueriesPerBlock * tile_offset.column()] =
|
||||
total_row;
|
||||
}
|
||||
});
|
||||
}
|
||||
__syncthreads();
|
||||
if (lane_id < kLinesPerWarp) {
|
||||
int id = warp_id * kLinesPerWarp + lane_id;
|
||||
accum_t total_row = s_prime[id];
|
||||
if (restore_mi_to_minus_inf) {
|
||||
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
|
||||
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
} else {
|
||||
m_prime[id] = mi[id];
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
|
||||
total_row += addition_storage[id + kQueriesPerBlock * i];
|
||||
}
|
||||
s_prime[id] = total_row;
|
||||
}
|
||||
}
|
||||
|
||||
static CUTLASS_DEVICE int8_t lane_id() {
|
||||
|
||||
112
examples/41_fused_multi_head_attention/piped_subprocess.py
Normal file
112
examples/41_fused_multi_head_attention/piped_subprocess.py
Normal file
@ -0,0 +1,112 @@
|
||||
from typing import List
|
||||
import torch
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
TORCH_DTYPE_NAME = {
|
||||
torch.float32: "f32",
|
||||
torch.float16: "f16",
|
||||
torch.bfloat16: "b16"
|
||||
}
|
||||
NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}
|
||||
|
||||
def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
|
||||
# PyTorch >= 2.0
|
||||
if hasattr(tensor, 'untyped_storage'):
|
||||
return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
|
||||
return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())
|
||||
|
||||
class PipedSubprocess:
|
||||
def __init__(self, binary: str) -> None:
|
||||
self.binary = binary
|
||||
self.tempdir_ctx = tempfile.TemporaryDirectory()
|
||||
|
||||
def __enter__(self) -> "PipedSubprocess":
|
||||
self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0)
|
||||
self.tempdir = self.tempdir_ctx.__enter__()
|
||||
self.file_counter = 0
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def temp_filename(self, suffix: str) -> str:
|
||||
self.file_counter += 1
|
||||
return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")
|
||||
|
||||
def write(self, *args) -> None:
|
||||
for a in args:
|
||||
self.subp.stdin.write(str(a) + " ")
|
||||
|
||||
def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
|
||||
print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
|
||||
tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
|
||||
self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
|
||||
filename = self.temp_filename(f"{name}.tensor")
|
||||
assert tensor.storage_offset() == 0
|
||||
with open(filename, "wb+") as fd:
|
||||
fd.write(bytes(tensor_u8.numpy()))
|
||||
self.write("file", filename)
|
||||
self.write("tensor_end")
|
||||
|
||||
for stride_name, stride_value in zip(stride_names, tensor.stride()):
|
||||
self.write(stride_name, stride_value)
|
||||
|
||||
def readTensor(self, name, stride_name, shape) -> torch.Tensor:
|
||||
tmpfile = self.temp_filename(f"{name}.tensor")
|
||||
self.write("tmpfile", tmpfile)
|
||||
|
||||
self.readExpect("tensor_begin")
|
||||
dtype_str, name = self.read().split(":")
|
||||
print(f"C++->Py : {dtype_str}:{name}")
|
||||
u8len = int(self.read())
|
||||
dtype = NAME_TORCH_DTYPE[dtype_str]
|
||||
|
||||
self.readExpect("file")
|
||||
self.readExpect(tmpfile)
|
||||
|
||||
with open(tmpfile, "rb") as fd:
|
||||
data = fd.read(u8len)
|
||||
# `np.array` is not strictly needed, but avoids a torch warning
|
||||
tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
|
||||
self.readExpect("tensor_end")
|
||||
|
||||
tensor = _tensor_from_storage(tensor_u8, dtype)
|
||||
strides = []
|
||||
for sn in stride_name:
|
||||
self.readExpect(sn)
|
||||
strides.append(int(self.read()))
|
||||
if len(strides) != shape:
|
||||
strides.append(1)
|
||||
assert len(strides) == len(shape), name
|
||||
return torch.as_strided(tensor, shape, strides)
|
||||
|
||||
def readNamed(self, name: str):
|
||||
self.readExpect(name)
|
||||
return self.read()
|
||||
|
||||
def readExpect(self, what: str) -> None:
|
||||
r = self.read()
|
||||
if r != what:
|
||||
raise ValueError(f"Read {r} but expected {what}")
|
||||
|
||||
def read(self):
|
||||
read_all = []
|
||||
# Skip initial whitespace
|
||||
while True:
|
||||
r = self.subp.stdout.read(1)
|
||||
if r not in [' ', "\n"]:
|
||||
read_all.append(r)
|
||||
break
|
||||
# Read data
|
||||
while True:
|
||||
r = self.subp.stdout.read(1)
|
||||
if r in [' ', "\n"]:
|
||||
break
|
||||
read_all.append(r)
|
||||
return ''.join(read_all)
|
||||
|
||||
@ -29,6 +29,8 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
|
||||
@ -434,14 +434,6 @@ class gen_device:
|
||||
" if (result != cudaSuccess) {\n" + \
|
||||
" return Status::kErrorInternal;\n" + \
|
||||
" }\n" + \
|
||||
"\n" + \
|
||||
" result = cudaFuncSetAttribute(\n" + \
|
||||
" Kernel<B2bGemmKernel>,\n" + \
|
||||
" cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n" + \
|
||||
"\n" + \
|
||||
" if (result != cudaSuccess) {\n" + \
|
||||
" return Status::kErrorInternal;\n" + \
|
||||
" }\n" + \
|
||||
" }\n" + \
|
||||
" cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);\n" + \
|
||||
" result = cudaGetLastError();\n" + \
|
||||
|
||||
@ -331,7 +331,7 @@ class gen_Kernel:
|
||||
operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
|
||||
|
||||
|
||||
operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n"
|
||||
operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
|
||||
operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
|
||||
|
||||
for i in range (self.b2bnum - 1):
|
||||
|
||||
@ -159,7 +159,7 @@ class DualGemm {
|
||||
using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape, WarpShape,
|
||||
ThreadblockShape, WarpShape,
|
||||
InstructionShape, Stages, Operator>::ThreadblockMma;
|
||||
using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
|
||||
@ -348,7 +348,7 @@ public:
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size,
|
||||
args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
|
||||
@ -167,10 +167,10 @@ bool run_nonfused_gemm_f16_sm80() {
|
||||
std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n";
|
||||
|
||||
bool pass = nonFusedGemm.run(
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
true /* is_profiling */
|
||||
);
|
||||
@ -248,10 +248,10 @@ bool run_fused_gemm_f16_sm80_shmem() {
|
||||
std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1
|
||||
);
|
||||
|
||||
@ -301,11 +301,11 @@ bool run_batched_fused_gemm_f16_sm80_shmem() {
|
||||
std::cout << "Running Batched Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
batch_problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
batch_problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
kBatchCount,
|
||||
false, /* broadcast_b1 */
|
||||
false /* is_profiling */
|
||||
@ -358,11 +358,11 @@ bool run_broadcast_fused_gemm_f16_sm80_shmem() {
|
||||
std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
1, /* batch_count */
|
||||
true, /* broadcast_b1 */
|
||||
true /* is_profiling */
|
||||
@ -415,11 +415,11 @@ bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() {
|
||||
std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n";
|
||||
|
||||
bool passed = fusedGemm.run(
|
||||
batch_problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
batch_problem_size,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
kBatchCount,
|
||||
true, /* broadcast_b1 */
|
||||
false /* is_profiling */
|
||||
@ -444,11 +444,11 @@ int main() {
|
||||
};
|
||||
|
||||
std::string test_name = (
|
||||
"dual-gemm f16 bias=" +
|
||||
std::to_string(kUseBias) +
|
||||
" split_k_serial=" +
|
||||
"dual-gemm f16 bias=" +
|
||||
std::to_string(kUseBias) +
|
||||
" split_k_serial=" +
|
||||
std::to_string(kSplitKSerial) +
|
||||
" batch_count=" +
|
||||
" batch_count=" +
|
||||
std::to_string(kBatchCount)
|
||||
);
|
||||
|
||||
|
||||
@ -45,6 +45,7 @@
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "cutlass/platform/platform.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
@ -356,13 +357,13 @@ struct NonFusedDualGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -564,22 +565,22 @@ struct DualFusedGemmRun
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementA,
|
||||
typename DualGemm::LayoutA> tensor_A0(
|
||||
std::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementB,
|
||||
typename DualGemm::LayoutB0> tensor_B0(
|
||||
std::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_C0(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
@ -589,22 +590,22 @@ struct DualFusedGemmRun
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D0(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D0(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementB,
|
||||
typename DualGemm::LayoutB1> tensor_B1(
|
||||
std::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
|
||||
if (broadcast_b1) {
|
||||
tensor_B1.resize({problem_size.k(), batch_count});
|
||||
@ -613,8 +614,8 @@ struct DualFusedGemmRun
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_C1(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
@ -624,29 +625,29 @@ struct DualFusedGemmRun
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D1(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> tensor_D2(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D1(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename DualGemm::ElementC,
|
||||
typename DualGemm::LayoutC> reference_D2(
|
||||
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
|
||||
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
|
||||
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -712,16 +713,16 @@ struct DualFusedGemmRun
|
||||
ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
|
||||
}
|
||||
typename DualGemm::Arguments arguments{
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::DualGemmMode::kBatched :
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::DualGemmMode::kBatched :
|
||||
cutlass::gemm::DualGemmMode::kGemm),
|
||||
problem_size,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
ref_B0,
|
||||
DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
|
||||
(broadcast_b1 ?
|
||||
typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
|
||||
(broadcast_b1 ?
|
||||
typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
|
||||
tensor_B1.device_ref()),
|
||||
ref_B1,
|
||||
DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
|
||||
@ -793,15 +794,15 @@ struct DualFusedGemmRun
|
||||
using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB0,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
ElementAccumulator
|
||||
>;
|
||||
|
||||
GemmUniversal0 reference_gemm0;
|
||||
|
||||
typename GemmUniversal0::Arguments args0 {
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::GemmUniversalMode::kBatched :
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::GemmUniversalMode::kBatched :
|
||||
cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
problem_size,
|
||||
batch_count,
|
||||
@ -828,15 +829,15 @@ struct DualFusedGemmRun
|
||||
using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
|
||||
typename DualGemm::ElementA, typename DualGemm::LayoutA,
|
||||
typename DualGemm::ElementB, typename DualGemm::LayoutB1,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
typename DualGemm::ElementC, typename DualGemm::LayoutC,
|
||||
ElementAccumulator
|
||||
>;
|
||||
|
||||
GemmUniversal1 reference_gemm1;
|
||||
|
||||
typename GemmUniversal1::Arguments args1 {
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::GemmUniversalMode::kBatched :
|
||||
(batch_count > 1 ?
|
||||
cutlass::gemm::GemmUniversalMode::kBatched :
|
||||
cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
problem_size,
|
||||
batch_count,
|
||||
@ -861,7 +862,7 @@ struct DualFusedGemmRun
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
|
||||
@ -300,7 +300,7 @@ struct DualGemm {
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
|
||||
ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
|
||||
ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
|
||||
ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
|
||||
|
||||
@ -309,7 +309,7 @@ struct DualGemm {
|
||||
//
|
||||
if (params.mode == DualGemmMode::kGemm) {
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
@ -364,7 +364,7 @@ struct DualGemm {
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
@ -413,11 +413,11 @@ struct DualGemm {
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
|
||||
ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
@ -425,7 +425,7 @@ struct DualGemm {
|
||||
if (params.mode == DualGemmMode::kGemm) {
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
|
||||
@ -759,13 +759,10 @@ public:
|
||||
accum1 = plus_accum(accum1, tmp_accum1);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -233,6 +233,17 @@ struct Options {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Filter size passed through command line does not match filter size template parameter
|
||||
if (filter_size.h() != FilterShape::kRow || filter_size.w() != FilterShape::kColumn) {
|
||||
std::cerr << "Filter size passed in (" << filter_size.h() << "x" << filter_size.w() << ") "
|
||||
<< "must match the FilterShape template parameter of the convolution "
|
||||
<< "(" << FilterShape::kRow << "x" << FilterShape::kColumn << "). "
|
||||
<< "To use the filter shape passed in, change the FilterShape template "
|
||||
<< "parameter and recompile this example."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -319,9 +330,9 @@ struct Options {
|
||||
"table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 "
|
||||
<< "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=32 "
|
||||
"--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n"
|
||||
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 "
|
||||
<< "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=1 "
|
||||
"--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
@ -515,14 +526,13 @@ Result profile_convolution(Options const &options) {
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue> >(problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref(),
|
||||
options.alpha,
|
||||
options.beta);
|
||||
ElementAccumulator >(problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref(),
|
||||
options.alpha,
|
||||
options.beta);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_d.sync_host();
|
||||
|
||||
@ -33,3 +33,7 @@ cutlass_example_add_executable(
|
||||
ampere_gemm_universal_streamk.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
47_ampere_gemm_universal_streamk_broadcast
|
||||
ampere_gemm_universal_streamk_broadcast.cu
|
||||
)
|
||||
|
||||
@ -495,7 +495,7 @@ int main(int argc, const char **argv)
|
||||
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
|
||||
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
|
||||
|
||||
// Fill matrix A on host with uniform-random data [4, -4]
|
||||
// Fill matrix A on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_a.host_view(),
|
||||
1,
|
||||
@ -503,7 +503,7 @@ int main(int argc, const char **argv)
|
||||
ElementA(-2),
|
||||
0);
|
||||
|
||||
// Fill matrix B on host with uniform-random data [4, -4]
|
||||
// Fill matrix B on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_b.host_view(),
|
||||
1,
|
||||
@ -511,7 +511,7 @@ int main(int argc, const char **argv)
|
||||
ElementB(-2),
|
||||
0);
|
||||
|
||||
// Fill matrix C on host with uniform-random data [4, -4]
|
||||
// Fill matrix C on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_c.host_view(),
|
||||
1,
|
||||
|
||||
@ -0,0 +1,738 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/***************************************************************************************************
|
||||
Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
|
||||
"classic data-parallel" and "Split-K" decompositions + residual add.
|
||||
|
||||
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
|
||||
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
|
||||
|
||||
Requires NVIDIA Ampere or newer device (SM80+).
|
||||
|
||||
- To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100)
|
||||
|
||||
cutlass$ sudo nvidia-smi -pm 1 -i 0
|
||||
|
||||
cutlass$ sudo nvidia-smi -i 0 -pl 400
|
||||
|
||||
cutlass$ sudo nvidia-smi -i 0 -lgc 1005
|
||||
|
||||
- Build and run:
|
||||
|
||||
cutlass$ mkdir build
|
||||
|
||||
cutlass$ cd build
|
||||
|
||||
cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
|
||||
|
||||
cutlass/build$ make 47_ampere_gemm_universal_streamk_broadcast
|
||||
|
||||
cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk_broadcast
|
||||
|
||||
- Reset clocks when done:
|
||||
|
||||
cutlass$ sudo nvidia-smi -rgc
|
||||
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/error_metrics.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_foreach.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C1/C2/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
|
||||
|
||||
// Output matrix configuration
|
||||
using ElementOutput = cutlass::half_t; // Element type for output matrix operands
|
||||
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
|
||||
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation
|
||||
using ElementCompute = cutlass::half_t; // Element type for compute
|
||||
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
|
||||
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
|
||||
constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT
|
||||
|
||||
// Residual block configuration
|
||||
|
||||
// Epilogue output operator
|
||||
/// Using LinearCombinationResidualBlock
|
||||
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
|
||||
ElementOutput, // Element type for output matrix
|
||||
ElementAccumulator, // Element type from internal accumulation
|
||||
ElementCompute, // Element type from internal accumulation
|
||||
ElementC, // Element type for C1/C2/D matrix operands
|
||||
AlignmentC, // Memory access granularity of C and D matrix in units of elements
|
||||
cutlass::epilogue::thread::Identity, // Activation
|
||||
cutlass::plus, // Binary operation 1
|
||||
cutlass::epilogue::thread::Identity, // Unary operation
|
||||
cutlass::plus // Binary operation 2
|
||||
>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
// Classic data-parallel device GEMM implementation type
|
||||
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
NumStages,
|
||||
AlignmentA,
|
||||
AlignmentB>;
|
||||
|
||||
// StreamK device GEMM implementation type with EVT
|
||||
using namespace cute;
|
||||
|
||||
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
ElementC,
|
||||
AlignmentC,
|
||||
EVTEpilogueStages
|
||||
>;
|
||||
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, ElementC,
|
||||
cute::Stride<_0, _1, int32_t> // StrideMNL
|
||||
>;
|
||||
|
||||
using C1 = cutlass::epilogue::threadblock::VisitorAuxLoad<
|
||||
OutputTileThreadMap, ElementC,
|
||||
cute::Stride<int64_t, _1, int64_t> // StrideMNL
|
||||
>;
|
||||
|
||||
using C2 = cutlass::epilogue::threadblock::VisitorAuxLoad<
|
||||
OutputTileThreadMap, ElementC,
|
||||
cute::Stride<int64_t, _1, int64_t> // StrideMNL
|
||||
>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::plus, ElementCompute, ElementCompute,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
|
||||
Compute0,
|
||||
Accum,
|
||||
Bias>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::plus, ElementCompute, ElementCompute,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
|
||||
Compute1,
|
||||
EVTCompute0,
|
||||
C1>;
|
||||
|
||||
using Compute2 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::plus, ElementOutput, ElementCompute,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute2 = cutlass::epilogue::threadblock::Sm80EVT<
|
||||
Compute2,
|
||||
EVTCompute1,
|
||||
C2>;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
cute::Stride<int64_t, _1, int64_t> // StrideMNL
|
||||
>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
|
||||
D,
|
||||
EVTCompute2>;
|
||||
|
||||
using EVTKernelStreamK =
|
||||
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
|
||||
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EVTD,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
EVTEpilogueStages
|
||||
>::GemmKernel;
|
||||
|
||||
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options
|
||||
{
|
||||
std::string command_name;
|
||||
bool help;
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
float alpha;
|
||||
float beta;
|
||||
int split_k_factor;
|
||||
int avail_sms;
|
||||
int iterations;
|
||||
bool real;
|
||||
|
||||
cutlass::HostTensor<ElementA, LayoutA> tensor_a;
|
||||
cutlass::HostTensor<ElementB, LayoutB> tensor_b;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_c1;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_c2;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_d;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_Vector;
|
||||
// cutlass::HostTensor<ElementC, LayoutC> tensor_Tensor;
|
||||
|
||||
Options(std::string command_name) :
|
||||
command_name(command_name),
|
||||
help(false),
|
||||
problem_size({2048, 2048, 2048}),
|
||||
alpha(1.0f),
|
||||
beta(1.0f),
|
||||
split_k_factor(1),
|
||||
avail_sms(-1), // Number of device SMs to use is unlimited
|
||||
real(false),
|
||||
iterations(10000)
|
||||
{}
|
||||
|
||||
bool valid() const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
void parse(int argc, char const **args)
|
||||
{
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("split", split_k_factor);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
real = cmd.check_cmd_line_flag("real");
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const
|
||||
{
|
||||
out
|
||||
<< "Performs a GEMM computation.\n"
|
||||
<< "\n"
|
||||
<< "Options:\n"
|
||||
<< "\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --split=<int> Split-K factor to emulate\n\n"
|
||||
<< " --real If specified, initializes with real values instead of whole numbers. Errors are to be expected.\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
|
||||
typename DeviceGemmBasic::Arguments args_from_options(
|
||||
const DeviceGemmBasic &device_gemm,
|
||||
const Options &options,
|
||||
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
|
||||
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector /*,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor */
|
||||
)
|
||||
{
|
||||
return typename DeviceGemmBasic::Arguments(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
|
||||
options.problem_size, // problem_size
|
||||
options.split_k_factor, // batch count / splitk slices
|
||||
{ // epilogue parameters
|
||||
ElementAccumulator(options.alpha),
|
||||
ElementAccumulator(options.beta)
|
||||
},
|
||||
tensor_a.device_data(), // ptr_A
|
||||
tensor_b.device_data(), // ptr_B
|
||||
tensor_c1.device_data(), // ptr_C1
|
||||
tensor_c2.device_data(), // ptr_C2
|
||||
tensor_d.device_data(), // ptr_D
|
||||
tensor_Vector.device_data(), // ptr_Vector
|
||||
/* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor
|
||||
options.problem_size.mk().product(), // batch_stride_A
|
||||
options.problem_size.nk().product(), // batch_stride_B
|
||||
options.problem_size.mn().product(), // batch_stride_C1
|
||||
options.problem_size.mn().product(), // batch_stride_C2
|
||||
options.problem_size.mn().product(), // batch_stride_D
|
||||
options.problem_size.mn().product(), // batch_stride_Vector
|
||||
options.problem_size.mn().product(), // batch_stride_Tensor
|
||||
tensor_a.layout().stride(0), // stride_a
|
||||
tensor_b.layout().stride(0), // stride_b
|
||||
tensor_c1.layout().stride(0), // stride_c1
|
||||
tensor_c2.layout().stride(0), // stride_c2
|
||||
tensor_d.layout().stride(0), // stride_d
|
||||
/*tensor_Vector.layout().stride(0)*/0, // stride_Vector
|
||||
/*tensor_Tensor.layout().stride(0)*/0); // stride_Tensor
|
||||
}
|
||||
|
||||
/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
|
||||
typename DeviceGemmStreamK::Arguments args_from_options(
|
||||
const DeviceGemmStreamK &device_gemm,
|
||||
const Options &options,
|
||||
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
|
||||
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector/*,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor*/
|
||||
)
|
||||
{
|
||||
typename EVTD::Arguments callback_args{
|
||||
{
|
||||
{
|
||||
{
|
||||
{}, // Accum
|
||||
{tensor_Vector.device_data(), ElementC(0), {_0{}, _1{}, int32_t(options.problem_size.n())}}, // Bias
|
||||
{} // Compute0
|
||||
}, // EVTCompute0
|
||||
{tensor_c1.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C1
|
||||
{} // Compute1
|
||||
}, // EVTCompute1
|
||||
{tensor_c2.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C2
|
||||
{} // Compute2
|
||||
}, // EVTCompute2
|
||||
{tensor_d.device_data(), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // D
|
||||
}; // EVTD
|
||||
|
||||
return typename DeviceGemmStreamK::Arguments(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
|
||||
options.problem_size, // problem_size
|
||||
options.split_k_factor, // batch count / splitk slices
|
||||
callback_args, // argument of EVT callbacks
|
||||
tensor_a.device_data(), // ptr_A
|
||||
tensor_b.device_data(), // ptr_B
|
||||
nullptr, // ptr_C (unused)
|
||||
nullptr, // ptr_D (unused)
|
||||
options.problem_size.mk().product(), // batch_stride_A
|
||||
options.problem_size.nk().product(), // batch_stride_B
|
||||
0, // batch_stride_C (unused)
|
||||
0, // batch_stride_D (unused)
|
||||
tensor_a.layout().stride(0), // stride_a
|
||||
tensor_b.layout().stride(0), // stride_b
|
||||
0, // stride_c (unused)
|
||||
0, // stride_d (unused)
|
||||
options.avail_sms); // avail_sms
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename DeviceGemmT>
|
||||
Result run(std::string description, Options &options)
|
||||
{
|
||||
// Display test description
|
||||
std::cout << std::endl << description << std::endl;
|
||||
|
||||
// Zero-initialize test output matrix D
|
||||
cutlass::reference::host::TensorFill(options.tensor_d.host_view());
|
||||
options.tensor_d.sync_device();
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
DeviceGemmT device_gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
|
||||
auto arguments = args_from_options(device_gemm, options,
|
||||
options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d,
|
||||
options.tensor_Vector/*, options.tensor_Tensor*/);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
CUTLASS_CHECK(device_gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(device_gemm());
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
options.tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = cutlass::reference::host::TensorEquals(
|
||||
options.tensor_d.host_view(),
|
||||
options.tensor_ref_d.host_view());
|
||||
|
||||
double err = cutlass::reference::host::TensorRelativeErrorMetric(
|
||||
options.tensor_d.host_view(),
|
||||
options.tensor_ref_d.host_view());
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(device_gemm());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPs: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
// TODO: uncomment when results match
|
||||
//if (!result.passed) {
|
||||
// exit(-1);
|
||||
//}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/// Program entrypoint
|
||||
int main(int argc, const char **argv)
|
||||
{
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Current device must must have compute capability at least 80
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
if (!((props.major * 10 + props.minor) >= 80))
|
||||
{
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Parse commandline options
|
||||
Options options("ampere_streamk_broadcast_gemm");
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::cout <<
|
||||
options.iterations << " timing iterations of " <<
|
||||
options.problem_size.m() << " x " <<
|
||||
options.problem_size.n() << " x " <<
|
||||
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Initialize GEMM datasets
|
||||
//
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
options.tensor_c1.resize(options.problem_size.mn()); // <- Create matrix C1 with dimensions M x N
|
||||
options.tensor_c2.resize(options.problem_size.mn()); // <- Create matrix C2 with dimensions M x N
|
||||
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
|
||||
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
|
||||
options.tensor_Vector.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions N x 1
|
||||
// options.tensor_Tensor.resize(options.problem_size.mn()); // <- Create T matrix with dimensions M x N
|
||||
|
||||
int _init_bits = options.real ? -1 : 0;
|
||||
|
||||
// Fill matrix A on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_a.host_view(),
|
||||
1,
|
||||
ElementA(2),
|
||||
ElementA(-2), _init_bits);
|
||||
|
||||
// Fill matrix B on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_b.host_view(),
|
||||
1,
|
||||
ElementB(2),
|
||||
ElementB(-2), _init_bits);
|
||||
|
||||
// Fill matrix C1 on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_c1.host_view(),
|
||||
1,
|
||||
ElementC(2),
|
||||
ElementC(-2), _init_bits);
|
||||
|
||||
// Fill matrix C2 on host with uniform-random data [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_c2.host_view(),
|
||||
1,
|
||||
ElementC(2),
|
||||
ElementC(-2), _init_bits);
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_Vector.host_view(),
|
||||
1,
|
||||
ElementC(2),
|
||||
ElementC(-2), _init_bits);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Copy data from host to GPU
|
||||
options.tensor_a.sync_device();
|
||||
options.tensor_b.sync_device();
|
||||
options.tensor_c1.sync_device();
|
||||
options.tensor_c2.sync_device();
|
||||
options.tensor_Vector.sync_device();
|
||||
// options.tensor_Tensor.sync_device();
|
||||
|
||||
// Zero-initialize reference output matrix D
|
||||
cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
|
||||
options.tensor_ref_d.sync_device();
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
options.problem_size,
|
||||
ElementAccumulator(options.alpha),
|
||||
options.tensor_a.device_ref(),
|
||||
options.tensor_b.device_ref(),
|
||||
ElementAccumulator(options.beta),
|
||||
options.tensor_c1.device_ref(),
|
||||
options.tensor_ref_d.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Copy output data from reference kernel to host for comparison
|
||||
options.tensor_ref_d.sync_host();
|
||||
|
||||
// Add broadcast vector (without multiplier)
|
||||
// This is only possible because BinaryOp is addition, and UnaryOps are identity.
|
||||
// This makes the addition of broadcast vector commutable.
|
||||
/// identity(plus(identity(alpha * (a * b) + v), beta * c)) ==
|
||||
/// alpha * a * b + v + beta * c ==
|
||||
/// (alpha * a * b + beta * c) + v ==
|
||||
/// GEMM(a, b, c) + v
|
||||
// Vector broadcast on host
|
||||
for (int i=0; i < options.problem_size.m(); ++i) {
|
||||
for (int j=0; j < options.problem_size.n(); ++j) {
|
||||
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j});
|
||||
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j});
|
||||
}
|
||||
}
|
||||
|
||||
// Sync back with device just in case
|
||||
options.tensor_ref_d.sync_device();
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
// Test default operation
|
||||
if (options.split_k_factor == 1)
|
||||
{
|
||||
// Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
|
||||
Result basic_dp = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
|
||||
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
|
||||
|
||||
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
|
||||
|
||||
// Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
|
||||
options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing)
|
||||
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
|
||||
options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)
|
||||
|
||||
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
|
||||
|
||||
options.split_k_factor++; // Increment splitting factor for next evaluation
|
||||
|
||||
}
|
||||
|
||||
// Show that StreamK can emulate "Split-K" with a tile-splitting factor
|
||||
Result basic_splitk = run<DeviceGemmBasic>(
|
||||
std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
|
||||
options);
|
||||
|
||||
Result streamk_splitk = run<DeviceGemmStreamK>(
|
||||
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
|
||||
options);
|
||||
|
||||
printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -60,6 +60,7 @@
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
@ -95,12 +96,13 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // M
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size
|
||||
using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
|
||||
@ -110,15 +112,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TilesShape, ClusterShape,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
@ -286,10 +293,10 @@ bool initialize_block(
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
|
||||
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
|
||||
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
@ -308,11 +315,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k},
|
||||
block_A.get(),
|
||||
stride_A,
|
||||
block_B.get(),
|
||||
stride_B,
|
||||
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
@ -320,7 +324,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
|
||||
@ -77,15 +77,28 @@
|
||||
will fit in shared memory given the types of operands and the thread block shape, rather than simply using
|
||||
a single default value.
|
||||
|
||||
Note that one does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide
|
||||
every template parameter to the gemm::collective::CollectiveMma. Specifying every template parameter in this
|
||||
manner remains the primary API for using CUTLASS 3 kernels. The CollectiveBuilder is simply meant to be
|
||||
a convenience interface.
|
||||
CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of
|
||||
the collective is specified via the schedule tags that corresond to the underlying collective's
|
||||
dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto`
|
||||
are special cases of these schedules that allow the builder to also decide the dispatch policy for you,
|
||||
therefore letting the builder pick the collective specialization.
|
||||
|
||||
Note also that, while the selections made by CollectiveBuilder attempt to maximize performance, this is not
|
||||
a guarantee. Furthermore, the behavior of the CollectiveBuilder when `Auto` parameters are provided is subject
|
||||
to change in future CUTLASS releases -- do not rely on `Auto` if you require a specific scheduling policy and/or
|
||||
stage count to be used.
|
||||
CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the
|
||||
assembled collectives have the best performance, but this is not a guarantee. A user relying on `Auto`
|
||||
may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized
|
||||
implementations that the builder can transparently assemble for `Auto`. But a user should not rely on
|
||||
`Auto` if they require a specific scheduling policy and/or stage count to be used.
|
||||
|
||||
If a user decides to let the builders pick the collective specialization via `Auto` schedules,
|
||||
they must be used for both mainloop and epilogue alike to ensure compatibility between the
|
||||
chosen collectives. Additionally, if a user chooses to opt in to a specific schedule, non-`Auto`
|
||||
schedules must be used for both mainloop and epilogue builder schedules, and these schedules
|
||||
must be compatible.
|
||||
|
||||
One does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide
|
||||
every template parameter to the `gemm::collective::CollectiveMma`. Specifying every template parameter
|
||||
in this manner remains the primary API for using CUTLASS 3 kernels. `CollectiveBuilder`s are
|
||||
simply meant to be a convenience interface.
|
||||
|
||||
Details of this example
|
||||
-----------------------
|
||||
@ -93,8 +106,15 @@
|
||||
This example also illustrates how CUTLASS 3 GEMMs targeting Hopper automatically support batched GEMMs by simply
|
||||
extending the problem size with an additional tensor rank.
|
||||
|
||||
CUTLASS 3.2 provides initial support for epilogue visitor trees (EVT) for the TMA warp-specialized collective.
|
||||
EVTs allow users to define their own customized epilogue fusion patterns without having to write a new
|
||||
collective epilogue. This is done by representing the fusion as a compute graph, where each node is one of a
|
||||
fundamental set of load, store, or compute operations. These operations are either elementwise for tensor
|
||||
inputs/outputs, broadcasts for vector/scalar inputs, or reductions for vector/scalar outputs.
|
||||
This example shows how users can define their own custom EVT and use it with the CollectiveBuilder.
|
||||
|
||||
Example usage:
|
||||
$ ./examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder \
|
||||
$ ./examples/49_hopper_with_collective_builder/49_collective_builder \
|
||||
--m=2048 --n=2048 --k=2048 --l=2
|
||||
*/
|
||||
|
||||
@ -108,8 +128,10 @@
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
@ -160,7 +182,7 @@ struct Options {
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "49_hopper_gemm_schedules_with_collective_builder\n\n"
|
||||
out << "49_hopper_with_collective_builder\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " performant kernels targeting NVIDIA's Hopper architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
@ -212,16 +234,30 @@ bool initialize_block(
|
||||
// operation builders by specializing the GEMM only on the kernel schedule it will use and the
|
||||
// number of pipeline stages.
|
||||
//
|
||||
// For either option, one can use a special `Auto` type that tells the CollectiveBuilder
|
||||
// One can use a special `Auto` type that tells the CollectiveBuilder
|
||||
// to select an appropriate value on its own. The CollectiveBuilder will attempt to select
|
||||
// values that will result in the most-performant kernel, but this is not a guarantee. Furthermore,
|
||||
// the behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases
|
||||
// configurations that will result in the most-performant kernel, but this is not a guarantee.
|
||||
//
|
||||
// If relying on 'Auto' schedules, all builders must use the 'Auto' schedule to ensure compatiblity.
|
||||
// For example, if `KernelScheduleAuto` is used for the mainloop builder, `EpilogueScheduleAuto` must
|
||||
// be used for the epilogue builder.
|
||||
//
|
||||
// Furthermore, if an override schedule is selected, both epilogue and mainloop schedules must
|
||||
// be specifically opt into a compatible selection.
|
||||
//
|
||||
// Behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases
|
||||
// -- do not rely on `Auto` if you require a specific scheduling policy.
|
||||
template <
|
||||
// Type of kernel schedule to generate
|
||||
class KernelScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
|
||||
class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
|
||||
// Type of epilogue schedule to generate
|
||||
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
// Number of pipeline stages to use
|
||||
class StageCountType = cutlass::gemm::collective::StageCountAuto
|
||||
class StageCountType = cutlass::gemm::collective::StageCountAuto,
|
||||
// Type of tile scheduler to use
|
||||
class TileSchedulerType = cutlass::gemm::PersistentScheduler,
|
||||
// Do we use custom epilogue visitor tree (EVT) fusion
|
||||
bool UseCustomEVT = false
|
||||
>
|
||||
struct ExampleRunner {
|
||||
|
||||
@ -230,27 +266,72 @@ struct ExampleRunner {
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using LayoutD = cutlass::layout::ColumnMajor;
|
||||
|
||||
static constexpr int kAlignmentA = 8;
|
||||
static constexpr int kAlignmentB = 8;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int AlignmentA = 16 / sizeof(ElementA);
|
||||
static constexpr int AlignmentB = 16 / sizeof(ElementB);
|
||||
static constexpr int AlignmentC = 16 / sizeof(ElementC);
|
||||
static constexpr int AlignmentD = 16 / sizeof(ElementD);
|
||||
|
||||
static_assert(not UseCustomEVT ||
|
||||
(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
|
||||
cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>),
|
||||
"Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
|
||||
// EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion
|
||||
// For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
|
||||
using CustomEVT = // alpha * acc + beta * C
|
||||
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
cutlass::epilogue::fusion::Sm90SrcFetch, // C
|
||||
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
cutlass::epilogue::fusion::Sm90AccFetch // acc
|
||||
>
|
||||
>;
|
||||
|
||||
// A predefined set of fusion operations (implemented with EVT) are supported by the TMA warp-specialized epilogue.
|
||||
// Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp
|
||||
// to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts.
|
||||
// These tags also provide additional metadata that can be queried at compile time.
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementScalar, RoundStyle>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, kAlignmentA,
|
||||
cutlass::half_t, LayoutB, kAlignmentB,
|
||||
float,
|
||||
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
|
||||
StageCountType,
|
||||
KernelScheduleType
|
||||
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
cute::conditional_t<UseCustomEVT, CustomEVT, DefaultOperation>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::gemm::TagToStrideC_t<LayoutD>,
|
||||
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>>;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
|
||||
cute::conditional_t<cute::is_same_v<StageCountType, cutlass::gemm::collective::StageCountAuto>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
StageCountType>,
|
||||
MainloopScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
CollectiveEpilogue,
|
||||
TileSchedulerType
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
@ -262,10 +343,10 @@ struct ExampleRunner {
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideA>());
|
||||
using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B<StrideB>());
|
||||
using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideC>());
|
||||
using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideD>());
|
||||
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
|
||||
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
|
||||
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
|
||||
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
@ -281,8 +362,8 @@ struct ExampleRunner {
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
@ -298,15 +379,15 @@ struct ExampleRunner {
|
||||
|
||||
cutlass::reference::device::GemmComplex(
|
||||
{M, N, K},
|
||||
typename Gemm::EpilogueOutputOp::ElementCompute(alpha),
|
||||
ElementScalar(alpha),
|
||||
ref_A,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ref_B,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
typename Gemm::EpilogueOutputOp::ElementCompute(beta),
|
||||
ElementScalar(beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
typename Gemm::EpilogueOutputOp::ElementAccumulator(0.f),
|
||||
ElementAccumulator(0),
|
||||
L, // batch_count
|
||||
M * K, // batch_stride_A
|
||||
K * N, // batch_stride_B
|
||||
@ -332,10 +413,10 @@ struct ExampleRunner {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(K * N * L);
|
||||
@ -356,14 +437,37 @@ struct ExampleRunner {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
block_A.get(),
|
||||
stride_A,
|
||||
block_B.get(),
|
||||
stride_B,
|
||||
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{}, // epilogue.thread
|
||||
block_C.get(), stride_C, block_D.get(), stride_D},
|
||||
hw_info
|
||||
};
|
||||
|
||||
// Custom EVT fusions will have nested unnamed args, the structure of which
|
||||
// can be deduced from the type definition of the EVT.
|
||||
// Each node's arguments has the recursive structure of
|
||||
// {first_child_args, ..., last_child_args, op_args},
|
||||
// For more complex examples of EVT initialization please refer to
|
||||
// include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
|
||||
if constexpr (UseCustomEVT) {
|
||||
arguments.epilogue.thread =
|
||||
{ // ternary op : beta * C + (alpha * acc)
|
||||
{{options.beta}}, // leaf op+args : beta
|
||||
{}, // leaf op+args : C
|
||||
{ // binary op : alpha * acc
|
||||
{{options.alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
}, // end binary op
|
||||
{} // ternary args : multiply_add
|
||||
}; // end ternary op
|
||||
}
|
||||
// Pre-defined fusions will have flat, named args for user-friendlyness
|
||||
else {
|
||||
arguments.epilogue.thread.alpha = options.alpha;
|
||||
arguments.epilogue.thread.beta = options.beta;
|
||||
}
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
@ -477,42 +581,69 @@ int main(int argc, char const **args) {
|
||||
// selected and the maximum number of stages that can fit in shared memory will be selected.
|
||||
//
|
||||
// This example is equivalent to declaring
|
||||
// ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, cutlass::gemm::collective::StageCountAuto>
|
||||
// ExampleRunner<
|
||||
// cutlass::gemm::collective::KernelScheduleAuto,
|
||||
// cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
// cutlass::gemm::collective::StageCountAuto>
|
||||
// Each of the `Auto` types indicate that the CollectiveBuilder should determine the scheduling policy and
|
||||
// stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change
|
||||
// -- do not rely on `Auto` if you require a specific scheduling policy.
|
||||
// If you opt in to a non-'Auto' schedule, make sure all collectives are built using specific, compatible schedules.
|
||||
ExampleRunner<> auto_schedule_auto_stage_runner;
|
||||
passed = auto_schedule_auto_stage_runner.run(options, hw_info);
|
||||
print_result("Automatically-selected schedule and stage count", passed);
|
||||
|
||||
// One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto
|
||||
// with the number of stages to use (5 in this case).
|
||||
ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, _5> auto_schedule_5_stage_runner;
|
||||
ExampleRunner<
|
||||
cutlass::gemm::collective::KernelScheduleAuto,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
_5> auto_schedule_5_stage_runner;
|
||||
|
||||
passed = auto_schedule_5_stage_runner.run(options, hw_info);
|
||||
print_result("Automatically-selected schedule with 5 stages", passed);
|
||||
|
||||
// One can also override the scheduling policy to use. In this case, use the KernelTma scheduling
|
||||
// policy, which specifies that the Hopper TMA feature should be used.
|
||||
ExampleRunner<cutlass::gemm::KernelTma> tma_schedule_auto_stage_runner;
|
||||
// policy, which specifies that the Hopper TMA feature should be used, and we also use an epilogue
|
||||
// that does not use any shared memory.
|
||||
ExampleRunner<cutlass::gemm::KernelTma, cutlass::epilogue::NoSmemWarpSpecialized> tma_schedule_auto_stage_runner;
|
||||
passed = tma_schedule_auto_stage_runner.run(options, hw_info);
|
||||
print_result("TMA schedule with automatically-selected stage count", passed);
|
||||
|
||||
// Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized
|
||||
// scheduling policy.
|
||||
//
|
||||
// Note that, as of the CUTLASS 3.0 release, this is the default scheduling policy
|
||||
// used by the CollectiveBuilder, so this declaration is equivalent to ExampleRunner<> and
|
||||
// ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto>. However, this default is subject to
|
||||
// change in future releases -- do not rely on `Auto` if you require a specific scheduling policy.
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized> ws_schedule_auto_stage_runner;
|
||||
// scheduling policy, and an epilogue that does not use any shared memory.
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized, cutlass::epilogue::NoSmemWarpSpecialized> ws_schedule_auto_stage_runner;
|
||||
passed = ws_schedule_auto_stage_runner.run(options, hw_info);
|
||||
print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed);
|
||||
|
||||
// Finally, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized
|
||||
// scheduling policy, leveraging persistent thread blocks.
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecializedPersistent> ws_persistent_schedule_auto_stage_runner;
|
||||
passed = ws_persistent_schedule_auto_stage_runner.run(options, hw_info);
|
||||
print_result("Persistent warp-specialized TMA schedule with automatically-selected stage count", passed);
|
||||
// Here, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized
|
||||
// scheduling policy, TMA-based epilogue, leveraging persistent thread blocks.
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong,
|
||||
cutlass::epilogue::TmaWarpSpecialized> ws_pingpong_schedule_auto_stage_runner;
|
||||
passed = ws_pingpong_schedule_auto_stage_runner.run(options, hw_info);
|
||||
print_result("Ping-pong warp-specialized TMA schedule with automatically-selected stage count", passed);
|
||||
|
||||
// Here, we override the scheduling policy to use stream-K problem decomposition atop the cooperative
|
||||
// warp-specialized scheduling policy. This kernel continues to leverage persistent thread blocks
|
||||
// as well aso TMA in both the mainloop and epilogue.
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::StreamKScheduler> ws_cooperative_stream_k_schedule_auto_stage_runner;
|
||||
passed = ws_cooperative_stream_k_schedule_auto_stage_runner.run(options, hw_info);
|
||||
print_result("Cooperative warp-specialized TMA schedule using stream-K with automatically-selected stage count", passed);
|
||||
|
||||
// Here, we override the fusion operation to use a customized EVT fusion, in addition to the previous schedule overrides
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::PersistentScheduler,
|
||||
true> ws_cooperative_schedule_auto_stage_custom_evt_runner;
|
||||
passed = ws_cooperative_schedule_auto_stage_custom_evt_runner.run(options, hw_info);
|
||||
print_result("Cooperative warp-specialized TMA schedule using custom epilogue visitor tree with automatically-selected stage count", passed);
|
||||
|
||||
#endif
|
||||
|
||||
@ -0,0 +1,34 @@
|
||||
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
|
||||
cutlass_example_add_executable(
|
||||
49_collective_builder
|
||||
49_collective_builder.cu
|
||||
)
|
||||
@ -34,7 +34,7 @@
|
||||
|
||||
The following example shows how to assemble a custom GEMM kernel that spells out the Collectives
|
||||
directly instead of using a builder and, in the process, instance a more efficient Epilogue
|
||||
(from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue.
|
||||
(from `cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp`) instead of using the default epilogue.
|
||||
|
||||
The GemmUniversal API takes 3 main template arguments:
|
||||
(1) the problem shape / extents
|
||||
@ -65,7 +65,7 @@
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/epilogue.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
@ -122,7 +122,7 @@ struct Options {
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "50_hopper_gemm_with_vectorized_epilogue\n\n"
|
||||
out << "50_hopper_gemm_with_epilogue_swizzle\n\n"
|
||||
<< "Hopper GEMM Example with Epilogue Swizzle.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
@ -262,10 +262,10 @@ struct ExampleRunner {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(K * N * L);
|
||||
@ -286,11 +286,8 @@ struct ExampleRunner {
|
||||
typename Gemm::GemmKernel::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
block_A.get(),
|
||||
stride_A,
|
||||
block_B.get(),
|
||||
stride_B,
|
||||
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
|
||||
hw_info
|
||||
};
|
||||
|
||||
@ -443,11 +440,11 @@ int main(int argc, char const **args) {
|
||||
cute::SM90_TMA_LOAD,
|
||||
cute::SM90_TMA_LOAD_MULTICAST>::type;
|
||||
|
||||
using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector<
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{}))
|
||||
>());
|
||||
|
||||
using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector<
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{}))
|
||||
>());
|
||||
|
||||
@ -494,14 +491,15 @@ int main(int argc, char const **args) {
|
||||
Stride<_16,_1>>,
|
||||
TileShapeS2R>;
|
||||
|
||||
using Epilogue = cutlass::epilogue::collective::Epilogue<
|
||||
using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
cutlass::epilogue::collective::Epilogue<
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::gemm::TagToStrideC_t<LayoutD>,
|
||||
cutlass::epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
|
||||
SmemLayout,
|
||||
Copy_Atom<DefaultCopy, ElementAcc>,
|
||||
TiledCopyS2R,
|
||||
Copy_Atom<DefaultCopy, ElementOutput>>;
|
||||
Copy_Atom<DefaultCopy, ElementOutput>>>;
|
||||
|
||||
//
|
||||
// Assembling the GemmKernel
|
||||
|
||||
371
examples/51_hopper_gett/51_hopper_gett.cu
Normal file
371
examples/51_hopper_gett/51_hopper_gett.cu
Normal file
@ -0,0 +1,371 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API.
|
||||
|
||||
CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels.
|
||||
However, a plethora of workloads compute on higher ranked tensors. Products of such tensors,
|
||||
called tensor contractions, can be executed as multiple batched GEMMs, however, they can be
|
||||
further accelerated with kernels that natively operate on these higher ranked tensors to
|
||||
perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts
|
||||
and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example,
|
||||
we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while
|
||||
making the process of authoring custom GETT kernels easier than ever before.
|
||||
|
||||
The modes of a tensor that participate in a GETT can be fundamentally grouped into four
|
||||
semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right)
|
||||
inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left
|
||||
input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the
|
||||
right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all
|
||||
input and output tensors. If we fold the many modes of a tensor contraction into these four
|
||||
categories, it would allow us to represent the input and output tensors as rank-3 "matrices"
|
||||
that can be computed upon as if we were computing a batched GEMM!
|
||||
|
||||
This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having
|
||||
simple integers as strides for these four modes, we can have nested strides for each of these
|
||||
semantic categories that themselves have multiple modes within them -- multi-mode strides!
|
||||
In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the
|
||||
required multi-mode strides instead of the default ones provided by gemm::detail::TagToStrideX.
|
||||
|
||||
In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise.
|
||||
We begin by defining the four modes detailed above as Row, Col (column), Red (reduction), and
|
||||
Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3 stride
|
||||
tuples. Note that although we do not define the problem shape type explicitely, it too remains a
|
||||
rank-4 shape tuple just like any other batched GEMM, but instead with multi-mode shapes for each
|
||||
of the four corresponding multi-modes within it. After this, the same CollectiveMma and
|
||||
CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing
|
||||
else changes from a user's point of view. Note that multi-mode strides do not affect our
|
||||
specializations in any way -- the lexical spelling of our kernels remains the same. The
|
||||
only difference between a CUTLASS 3 batched GEMM and GETT are the instaced CuTe Layouts.
|
||||
|
||||
CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode,
|
||||
which is what the example demonstrates. However, it is possible to have all modes be dynamic as well
|
||||
if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible
|
||||
with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other
|
||||
hand, a user can have more than one static stride too (which need not correspond to the major mode).
|
||||
|
||||
In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0)
|
||||
in B are major. All other combinations of major modes are supported, with the exception of mixed
|
||||
K-major scenarios where both A and B are K-major (e.g. K0 is major in A but K1 is major in B).
|
||||
NVIDIA Hopper architecture's TMA feature makes the predictaion required to implement these complicated
|
||||
kernels trivial, as it is all handled by TMA itself without requiring any programmer effort.
|
||||
|
||||
Example executions, where the stride order defines the major-order (major on the left):
|
||||
51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096
|
||||
51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64
|
||||
51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3
|
||||
*/
|
||||
|
||||
#include "gett_kernel.cuh"
|
||||
#include "thrust/host_vector.h"
|
||||
#include "thrust/device_vector.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/gett_commandline.hpp"
|
||||
#include "cutlass/util/reference/device/gett.hpp"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/print_error.hpp"
|
||||
|
||||
namespace example {
|
||||
|
||||
// Returns true if the left-most value in the tuple is statically known to be 1
|
||||
template<class Stride>
|
||||
constexpr bool
|
||||
is_left_major() {
|
||||
// Account for stride types with and without batch mode and batch modes with static zero stride
|
||||
return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
|
||||
}
|
||||
|
||||
// Same as cute::make_int_tuple but inserts a major stride (Int<1>) for the leftmost mode if required
|
||||
template <int Rank, bool IsMajor, class Indexable>
|
||||
static constexpr
|
||||
auto
|
||||
make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) {
|
||||
static_assert(Rank > 1);
|
||||
if constexpr (IsMajor) {
|
||||
return cute::transform(cute::make_seq<Rank>{}, [&](auto i) {
|
||||
if constexpr (i == 0) {
|
||||
return cute::Int<1>{};
|
||||
}
|
||||
else {
|
||||
return i < n ? t[i] : init_default;
|
||||
}
|
||||
});
|
||||
}
|
||||
else {
|
||||
return cute::make_int_tuple<Rank>(t, n, init_default);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace example
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int
|
||||
main(int argc, char const* argv[]) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
using namespace cute;
|
||||
|
||||
if (argc != 5) {
|
||||
std::cout << "Number of command line args must be 4.\n";
|
||||
cutlass::GettCommandLine::print_usage();
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Define the stride types for A, B, C, and D
|
||||
//
|
||||
|
||||
// Stride for A (left input). If reduction mode is major, same must be major in B
|
||||
// For this example, M0 is major in A.
|
||||
using RowModeStridesA = cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>;
|
||||
using RedModeStridesA = cute::Stride<int64_t, int64_t, int64_t>;
|
||||
using BatModeStridesA = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
|
||||
|
||||
// Stride for B (right input). If reduction mode is major, same must be major in A
|
||||
// For this example, K0 is major in B.
|
||||
using ColModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
|
||||
using RedModeStridesB = cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using BatModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
|
||||
|
||||
// Strides for output, which can all be dynamic.
|
||||
using RowModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
|
||||
using ColModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
|
||||
using BatModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
|
||||
|
||||
// Assmble our rank-3 multi-mode strides for the in/out tensors
|
||||
using StrideA = cute::Stride<RowModeStridesA, RedModeStridesA, BatModeStridesA>;
|
||||
using StrideB = cute::Stride<ColModeStridesB, RedModeStridesB, BatModeStridesB>;
|
||||
using StrideC = cute::Stride<RowModeStridesC, ColModeStridesC, BatModeStridesC>;
|
||||
|
||||
// Note: C and D share strides here for simplicity.
|
||||
// In general, they need not have the same layout.
|
||||
using StrideD = StrideC;
|
||||
|
||||
//
|
||||
// Define element types for tensors and intermediate values
|
||||
//
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = float;
|
||||
using ElementAccumulator = float;
|
||||
using ElementEpilogue = float;
|
||||
|
||||
// The following constexpr values set the max number of modes in each MNKL mode
|
||||
constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
|
||||
constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
|
||||
constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
|
||||
constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
|
||||
static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
|
||||
static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{}));
|
||||
static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
|
||||
static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
|
||||
static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));
|
||||
|
||||
// Parse command line to get modes, extents, and strides
|
||||
cutlass::GettCommandLine cmd;
|
||||
auto parsed_args = cmd.parse(argc, argv, true);
|
||||
|
||||
auto& m = parsed_args.M;
|
||||
auto& ldAm = parsed_args.ldAm;
|
||||
auto& ldCm = parsed_args.ldCm;
|
||||
int rank_m = int(m.size());
|
||||
|
||||
auto& n = parsed_args.N;
|
||||
auto& ldBn = parsed_args.ldBn;
|
||||
auto& ldCn = parsed_args.ldCn;
|
||||
int rank_n = int(n.size());
|
||||
|
||||
auto& k = parsed_args.K;
|
||||
auto& ldAk = parsed_args.ldAk;
|
||||
auto& ldBk = parsed_args.ldBk;
|
||||
int rank_k = int(k.size());
|
||||
|
||||
auto& l = parsed_args.L;
|
||||
auto& ldAl = parsed_args.ldAl;
|
||||
auto& ldBl = parsed_args.ldBl;
|
||||
auto& ldCl = parsed_args.ldCl;
|
||||
int rank_l = int(l.size());
|
||||
|
||||
if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) {
|
||||
std::cerr << "ERROR: Input has more modes than statically configured.";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check that the user input major stride match the static major strides.
|
||||
if (example::is_left_major<RowModeStridesA>() && (ldAm[0] != 1)) {
|
||||
std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (example::is_left_major<RedModeStridesA>() && (ldAk[0] != 1)) {
|
||||
std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (example::is_left_major<ColModeStridesB>() && (ldBn[0] != 1)) {
|
||||
std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (example::is_left_major<RedModeStridesB>() && (ldBk[0] != 1)) {
|
||||
std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Convert to `cute::Tuple`s and set up arguments
|
||||
auto M = make_int_tuple<MaxRank_M>(m.data(), rank_m, 1);
|
||||
auto dAm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesA>()>(ldAm.data(), rank_m);
|
||||
auto dCm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesC>()>(ldCm.data(), rank_m);
|
||||
|
||||
auto N = make_int_tuple<MaxRank_N>(n.data(), rank_n, 1);
|
||||
auto dBn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesB>()>(ldBn.data(), rank_n);
|
||||
auto dCn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesC>()>(ldCn.data(), rank_n);
|
||||
|
||||
auto K = make_int_tuple<MaxRank_K>(k.data(), rank_k, 1);
|
||||
auto dAk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesA>()>(ldAk.data(), rank_k);
|
||||
auto dBk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesB>()>(ldBk.data(), rank_k);
|
||||
|
||||
auto L = make_int_tuple<MaxRank_L>(l.data(), rank_l, 1);
|
||||
auto dAl = make_int_tuple<MaxRank_L>(ldAl.data(), rank_l, 0);
|
||||
auto dBl = make_int_tuple<MaxRank_L>(ldBl.data(), rank_l, 0);
|
||||
auto dCl = make_int_tuple<MaxRank_L>(ldCl.data(), rank_l, 0);
|
||||
|
||||
// Concat tuples to turn it into rank-4 problem shape and rank-3 strides, just like GEMM
|
||||
auto problem_shape = make_shape(M, N, K, L);
|
||||
StrideA stride_A = make_stride(dAm, dAk, dAl);
|
||||
StrideB stride_B = make_stride(dBn, dBk, dBl);
|
||||
StrideC stride_C = make_stride(dCm, dCn, dCl);
|
||||
StrideD stride_D = stride_C;
|
||||
|
||||
auto alpha = ElementEpilogue(1.0f);
|
||||
auto beta = ElementEpilogue(1.0f);
|
||||
|
||||
//
|
||||
// Allocate and init tensors
|
||||
//
|
||||
auto M_size = std::accumulate(std::begin(m), std::end(m), 1, std::multiplies<>{});
|
||||
auto N_size = std::accumulate(std::begin(n), std::end(n), 1, std::multiplies<>{});
|
||||
auto K_size = std::accumulate(std::begin(k), std::end(k), 1, std::multiplies<>{});
|
||||
auto L_size = std::accumulate(std::begin(l), std::end(l), 1, std::multiplies<>{});
|
||||
|
||||
thrust::host_vector<ElementA> h_A(M_size * K_size * L_size);
|
||||
thrust::host_vector<ElementB> h_B(N_size * K_size * L_size);
|
||||
thrust::host_vector<ElementC> h_C(M_size * N_size * L_size);
|
||||
thrust::host_vector<ElementD> h_D(M_size * N_size * L_size);
|
||||
|
||||
// Note: the cast to int here is to avoid false-negative ref-checks which can
|
||||
// occur due to floating point arithmetic not being purely associative.
|
||||
for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1));
|
||||
for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1));
|
||||
for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1));
|
||||
for (auto& d : h_D) d = ElementD(-1);
|
||||
|
||||
thrust::device_vector<ElementA> d_A = h_A;
|
||||
thrust::device_vector<ElementB> d_B = h_B;
|
||||
thrust::device_vector<ElementC> d_C = h_C;
|
||||
thrust::device_vector<ElementD> cutlass_result = h_D;
|
||||
thrust::device_vector<ElementD> reference_result = h_D;
|
||||
|
||||
//
|
||||
// Compute GETT
|
||||
//
|
||||
auto status = example::gett_kernel(
|
||||
problem_shape,
|
||||
d_A.data().get(), stride_A,
|
||||
d_B.data().get(), stride_B,
|
||||
ElementAccumulator{},
|
||||
d_C.data().get(), stride_C,
|
||||
cutlass_result.data().get(), stride_D,
|
||||
alpha, beta);
|
||||
|
||||
if (cutlass::Status::kSuccess != status) {
|
||||
std::cerr << "ERROR: GETT operator launch failed.\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto cuda_err = cudaDeviceSynchronize();
|
||||
if (cudaSuccess != cuda_err) {
|
||||
std::cerr << "ERROR: GETT operator execution failed. with error :";
|
||||
std::cerr << cudaGetErrorString(cuda_err) << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::gett(
|
||||
problem_shape,
|
||||
d_A.data().get(), stride_A,
|
||||
d_B.data().get(), stride_B,
|
||||
ElementAccumulator{},
|
||||
d_C.data().get(), stride_C,
|
||||
reference_result.data().get(), stride_D,
|
||||
alpha, beta);
|
||||
|
||||
cuda_err = cudaDeviceSynchronize();
|
||||
if (cudaSuccess != cuda_err) {
|
||||
std::cerr << "ERROR: GETT reference execution failed. with error :";
|
||||
std::cerr << cudaGetErrorString(cuda_err) << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(
|
||||
reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size());
|
||||
if (passed) {
|
||||
std::cout << "GETT verification passed.\n";
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n";
|
||||
h_D = reference_result;
|
||||
thrust::host_vector<ElementD> h_cutlass_result = cutlass_result;
|
||||
print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data());
|
||||
|
||||
std::cout << "StrideA: "; print(stride_A); std::cout << '\n';
|
||||
std::cout << "StrideB: "; print(stride_B); std::cout << '\n';
|
||||
std::cout << "StrideC: "; print(stride_C); std::cout << '\n';
|
||||
std::cout << "StrideD: "; print(stride_D); std::cout << '\n';
|
||||
return 1;
|
||||
}
|
||||
#else
|
||||
std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n";
|
||||
return 0;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
}
|
||||
32
examples/51_hopper_gett/CMakeLists.txt
Normal file
32
examples/51_hopper_gett/CMakeLists.txt
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
51_hopper_gett
|
||||
51_hopper_gett.cu
|
||||
)
|
||||
137
examples/51_hopper_gett/gett_kernel.cuh
Normal file
137
examples/51_hopper_gett/gett_kernel.cuh
Normal file
@ -0,0 +1,137 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
namespace example {
|
||||
|
||||
//
|
||||
// GETT entry point
|
||||
//
|
||||
template <
|
||||
class ProblemShapeMNKL,
|
||||
class ElementA,
|
||||
class StrideA,
|
||||
class ElementB,
|
||||
class StrideB,
|
||||
class ElementAccumulator,
|
||||
class ElementC,
|
||||
class StrideC,
|
||||
class ElementD,
|
||||
class StrideD,
|
||||
class ElementEpilogue>
|
||||
cutlass::Status
|
||||
gett_kernel(
|
||||
ProblemShapeMNKL problem_shape_mnkl,
|
||||
ElementA const* ptr_A, StrideA stride_a_mkl,
|
||||
ElementB const* ptr_B, StrideB stride_b_nkl,
|
||||
ElementAccumulator _,
|
||||
ElementC const* ptr_C, StrideC stride_c_mnl,
|
||||
ElementD * ptr_D, StrideD stride_d_mnl,
|
||||
ElementEpilogue alpha, ElementEpilogue beta,
|
||||
cudaStream_t stream = 0) {
|
||||
using namespace cute;
|
||||
|
||||
// TileShape -- GETT configuration
|
||||
// Specify the number of elements to take from each mode
|
||||
// BLK_M = (M0,M1,...) BLK_N = (M0,M1,...) BLK_K = (K0,K1,...)
|
||||
|
||||
// Take 128 from m0, 128 from n0, 64 from k0
|
||||
using TileShape = Shape<Shape<_128>, Shape<_128>, Shape<_64>>;
|
||||
|
||||
/* Other examples:
|
||||
* Take 32 elements from m0 and 4 elements from m1
|
||||
* Take 64 elements from n0 and 2 elements from n1
|
||||
* Take 8 elements from k0 and 8 elements from k1
|
||||
**/
|
||||
// using TileShape = Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>>;
|
||||
|
||||
using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
// No changes are required to the default epilogue
|
||||
using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
||||
cutlass::epilogue::collective::DefaultEpilogue<
|
||||
StrideC,
|
||||
StrideD,
|
||||
EpilogueThreadOp,
|
||||
cutlass::gemm::EpilogueDefault>>;
|
||||
|
||||
// CollectiveMma for GETTs can be built using the CollectiveBuilders
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, StrideA, 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
ElementAccumulator,
|
||||
TileShape, Shape<_1,_2,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM
|
||||
using GettKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShapeMNKL,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using GettOperator = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;
|
||||
|
||||
typename GettOperator::Arguments args {
|
||||
cutlass::gemm::GemmUniversalMode::kBatched,
|
||||
problem_shape_mnkl,
|
||||
{ ptr_A, stride_a_mkl, ptr_B, stride_b_nkl },
|
||||
{ {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl }
|
||||
};
|
||||
|
||||
#if CUTLASS_DEBUG_TRACE_LEVEL > 0
|
||||
print("Problem shape:");
|
||||
print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n");
|
||||
print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n");
|
||||
print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n");
|
||||
print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n");
|
||||
print("TileSape:"); print(TileShape{}); print("\n");
|
||||
#endif
|
||||
|
||||
GettOperator op;
|
||||
return op(args, stream);
|
||||
}
|
||||
|
||||
} // namespace example
|
||||
@ -0,0 +1,687 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example of a Hopper gather+GEMM+scatter kernel fusion.
|
||||
|
||||
This example fuses gather before GEMM and scatter after GEMM into the same
|
||||
GEMM kernel. Gather and scatter operation is controled by an index vector
|
||||
to select rows or columns from A, B, C or D matrices.
|
||||
|
||||
Gather/scatter operations are always performed along a strided dimension
|
||||
in order to preserve vectorized loads/stores. Thus the index vector is
|
||||
applied to rows of row-major matrices and columns of column-major matrices.
|
||||
|
||||
Note that the index vector must contain integers in range [0,X) where
|
||||
X is one of (M,N,K), depending on selected gather dimension. The problem
|
||||
shape given to the GEMM kernel must consist of matrix sizes AFTER gather
|
||||
and BEFORE scatter operations are applied.
|
||||
*/
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <assert.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <numeric>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
#include "gather_gemm.hpp"
|
||||
#include "gather_kernel.cuh"
|
||||
#include "scatter_epilogue.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace example {
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
cutlass::gemm::BatchedGemmCoord problem_size = {2048, 2048, 2048, 1};
|
||||
int index_size = 1024;
|
||||
int mode = 1; // N-mode gather/scatter by default
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 1.0f;
|
||||
|
||||
bool reference_check = true;
|
||||
int iterations = 20;
|
||||
|
||||
bool valid() const {
|
||||
return problem_size.m() > 0
|
||||
&& problem_size.n() > 0
|
||||
&& problem_size.k() > 0
|
||||
&& problem_size.batch() > 0
|
||||
&& 0 <= mode && mode < 3
|
||||
&& index_size <= problem_size.at(mode)
|
||||
&& iterations > 0;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
cmd.get_cmd_line_argument("batch_size", problem_size.batch());
|
||||
cmd.get_cmd_line_argument("index_size", index_size);
|
||||
|
||||
char const modes[] = {'m', 'n', 'k'};
|
||||
char mode_input = modes[mode];
|
||||
cmd.get_cmd_line_argument("mode", mode_input);
|
||||
mode = int(std::distance(std::begin(modes), std::find(std::begin(modes), std::end(modes), mode_input)));
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("check", reference_check, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out <<
|
||||
"52_hopper_gather_scatter_fusion example\n"
|
||||
"\n"
|
||||
" This example uses the CUTLASS Library to fuse gather/scatter of input/output tensors with GEMM.\n"
|
||||
" It validates and benchmarks the fused kernel against an unfused implementation that executes\n"
|
||||
" gather+GEMM+scatter in sequence and writes intermediate (gathered) tensors to memory.\n"
|
||||
" For the unfused implementation two GEMM kernels are considered: default one that uses the same\n"
|
||||
" schedule and instruction set as the fused one, and an optimized one that utilizes advanced\n"
|
||||
" features (such as TMA units) that cannot be used by the fused kernel due to hardware constraints."
|
||||
"\n"
|
||||
"Options:\n"
|
||||
" --help If specified, displays this usage statement.\n"
|
||||
" --m=<int> GEMM M dimension\n"
|
||||
" --n=<int> GEMM N dimension\n"
|
||||
" --k=<int> GEMM K dimension\n"
|
||||
" --batch_size=<int> GEMM batch size\n"
|
||||
" --index_size=<int> Size of N dimension gather/scatter index\n"
|
||||
" --mode=<m,n,k> Gather mode (M, N, or K)\n"
|
||||
" --alpha=<float> GEMM alpha parameter\n"
|
||||
" --beta=<float> GEMM beta parameter\n"
|
||||
" --iterations=<int> Number of profiling iterations to perform.\n"
|
||||
"\n"
|
||||
"Examples:\n"
|
||||
"\n"
|
||||
"$ ./examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion --m=1024 --n=2048 --k=1024 --mode=n --index_size=1024\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class ElementA, class LayoutA, class GatherA,
|
||||
class ElementB, class LayoutB, class GatherB,
|
||||
class ElementC, class LayoutC, class GatherC,
|
||||
class ElementD, class LayoutD, class ScatterD,
|
||||
class ElementAccumulator, class ElementComputeEpilogue>
|
||||
struct ExampleRunner
|
||||
{
|
||||
// Useful aliases
|
||||
|
||||
// Alias to for the epilogue type that supports gather/scatter
|
||||
using Epilogue = cutlass::epilogue::collective::EpilogueGatherScatter<
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::gemm::TagToStrideC_t<LayoutD>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementD, 1,
|
||||
ElementAccumulator, ElementComputeEpilogue,
|
||||
cutlass::epilogue::thread::ScaleType::Default,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, ElementC
|
||||
>,
|
||||
cutlass::gemm::EpilogueDefault,
|
||||
GatherC,
|
||||
ScatterD
|
||||
>;
|
||||
|
||||
// Alias to for the mainloop type
|
||||
using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
ElementB, LayoutB, 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
ElementAccumulator,
|
||||
Shape<_128,_128,_64>,
|
||||
Shape<_1,_1,_1>,
|
||||
cutlass::gemm::collective::StageCount<5>,
|
||||
cutlass::gemm::KernelMultistage
|
||||
>::CollectiveOp;
|
||||
|
||||
using ProblemShape = Shape<int,int,int,int>;
|
||||
|
||||
using Kernel = cutlass::gemm::kernel::GemmGather<
|
||||
ProblemShape,
|
||||
Mainloop,
|
||||
Epilogue,
|
||||
GatherA,
|
||||
GatherB
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;
|
||||
|
||||
using StrideA = typename Kernel::StrideA;
|
||||
using StrideB = typename Kernel::StrideB;
|
||||
using StrideC = typename Kernel::StrideC;
|
||||
using StrideD = typename Kernel::StrideD;
|
||||
|
||||
static constexpr bool DoGatherA = not cutlass::platform::is_same<GatherA, NoGather>::value;
|
||||
static constexpr bool DoGatherB = not cutlass::platform::is_same<GatherB, NoGather>::value;
|
||||
static constexpr bool DoGatherC = not cutlass::platform::is_same<GatherC, NoGather>::value;
|
||||
static constexpr bool DoScatterD = not cutlass::platform::is_same<ScatterD, NoGather>::value;
|
||||
|
||||
static constexpr bool GatherAonM = DoGatherA && cutlass::platform::is_same<LayoutA,cutlass::layout::RowMajor>::value;
|
||||
static constexpr bool GatherAonK = DoGatherA && cutlass::platform::is_same<LayoutA,cutlass::layout::ColumnMajor>::value;
|
||||
static constexpr bool GatherBonN = DoGatherB && cutlass::platform::is_same<LayoutB,cutlass::layout::ColumnMajor>::value;
|
||||
static constexpr bool GatherBonK = DoGatherB && cutlass::platform::is_same<LayoutB,cutlass::layout::RowMajor>::value;
|
||||
static constexpr bool GatherConM = DoGatherC && cutlass::platform::is_same<LayoutC,cutlass::layout::RowMajor>::value;
|
||||
static constexpr bool GatherConN = DoGatherC && cutlass::platform::is_same<LayoutC,cutlass::layout::ColumnMajor>::value;
|
||||
static constexpr bool ScatterDonM = DoScatterD && cutlass::platform::is_same<LayoutD,cutlass::layout::RowMajor>::value;
|
||||
static constexpr bool ScatterDonN = DoScatterD && cutlass::platform::is_same<LayoutD,cutlass::layout::ColumnMajor>::value;
|
||||
|
||||
static constexpr bool GatherModeM = GatherAonM || GatherConM || ScatterDonM;
|
||||
static constexpr bool GatherModeN = GatherBonN || GatherConN || ScatterDonN;
|
||||
static constexpr bool GatherModeK = GatherAonK || GatherBonK;
|
||||
|
||||
static_assert( GatherModeM && !GatherModeN && !GatherModeK ||
|
||||
!GatherModeM && GatherModeN && !GatherModeK ||
|
||||
!GatherModeM && !GatherModeN && GatherModeK,
|
||||
"Only one gather mode (M, N or K) is supported by example runner");
|
||||
|
||||
// Construct a reference (non-gather) GEMM kernel type
|
||||
|
||||
using MainloopRef = Mainloop;
|
||||
|
||||
using EpilogueRef = typename cutlass::epilogue::collective::DefaultEpilogue<
|
||||
StrideC, StrideD,
|
||||
typename Epilogue::ThreadEpilogueOp,
|
||||
typename Epilogue::EpilogueSchedule
|
||||
>;
|
||||
|
||||
using KernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
MainloopRef,
|
||||
EpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<KernelRef>;
|
||||
|
||||
// Construct an optimized reference GEMM kernel type (using TMA)
|
||||
|
||||
using EpilogueOpt = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
Shape<_128,_128,_64>,
|
||||
Shape<_2,_2,_1>,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementComputeEpilogue,
|
||||
ElementC, LayoutC, 128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementD, LayoutD, 128 / cutlass::sizeof_bits<ElementD>::value,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using MainloopOpt = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
ElementB, LayoutB, 128 / cutlass::sizeof_bits<ElementB>::value,
|
||||
ElementAccumulator,
|
||||
Shape<_128,_128,_64>,
|
||||
Shape<_2,_2,_1>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOpt::SharedStorage)>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelOpt = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
MainloopOpt,
|
||||
EpilogueOpt
|
||||
>;
|
||||
|
||||
using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter<KernelOpt>;
|
||||
|
||||
// Data members
|
||||
|
||||
cutlass::gemm::BatchedGemmCoord problem_size_orig;
|
||||
cutlass::gemm::BatchedGemmCoord problem_size;
|
||||
ProblemShape problem_shape_orig;
|
||||
ProblemShape problem_shape;
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
|
||||
StrideA stride_A_orig;
|
||||
StrideB stride_B_orig;
|
||||
StrideC stride_C_orig;
|
||||
StrideD stride_D_orig;
|
||||
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
|
||||
cutlass::device_memory::allocation<ElementA> tensor_a;
|
||||
cutlass::device_memory::allocation<ElementB> tensor_b;
|
||||
cutlass::device_memory::allocation<ElementC> tensor_c;
|
||||
cutlass::device_memory::allocation<ElementD> tensor_d;
|
||||
|
||||
cutlass::device_memory::allocation<int> gather_indices;
|
||||
|
||||
cutlass::device_memory::allocation<ElementA> tensor_a_gathered;
|
||||
cutlass::device_memory::allocation<ElementB> tensor_b_gathered;
|
||||
cutlass::device_memory::allocation<ElementC> tensor_c_gathered;
|
||||
cutlass::device_memory::allocation<ElementD> tensor_d_gathered;
|
||||
cutlass::device_memory::allocation<ElementD> tensor_d_reference;
|
||||
|
||||
cutlass::gemm::GemmUniversalMode gemm_mode;
|
||||
|
||||
Gemm gemm;
|
||||
typename Gemm::Arguments arguments;
|
||||
cutlass::device_memory::allocation<uint8_t> workspace;
|
||||
|
||||
GemmRef gemm_ref;
|
||||
typename GemmRef::Arguments arguments_ref;
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_ref;
|
||||
|
||||
GemmOpt gemm_opt;
|
||||
typename GemmOpt::Arguments arguments_opt;
|
||||
cutlass::device_memory::allocation<uint8_t> workspace_opt;
|
||||
|
||||
ExampleRunner(Options const &options, cutlass::KernelHardwareInfo const &hw_info)
|
||||
: problem_size_orig(options.problem_size),
|
||||
problem_size(GatherModeM ? options.index_size : problem_size_orig.m(),
|
||||
GatherModeN ? options.index_size : problem_size_orig.n(),
|
||||
GatherModeK ? options.index_size : problem_size_orig.k(),
|
||||
problem_size_orig.batch()),
|
||||
problem_shape_orig(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()),
|
||||
problem_shape(problem_size.m(), problem_size.n(), problem_size.k(), problem_size.batch()),
|
||||
hw_info(hw_info),
|
||||
alpha(options.alpha),
|
||||
beta(options.beta),
|
||||
stride_A_orig(cutlass::make_cute_packed_stride(
|
||||
StrideA{}, make_shape(problem_size_orig.m(), problem_size_orig.k(), problem_size_orig.batch()))),
|
||||
stride_B_orig(cutlass::make_cute_packed_stride(
|
||||
StrideB{}, make_shape(problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()))),
|
||||
stride_C_orig(cutlass::make_cute_packed_stride(
|
||||
StrideC{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))),
|
||||
stride_D_orig(cutlass::make_cute_packed_stride(
|
||||
StrideD{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))),
|
||||
stride_A(cutlass::make_cute_packed_stride(
|
||||
StrideA{}, make_shape(problem_size.m(), problem_size.k(), problem_size.batch()))),
|
||||
stride_B(cutlass::make_cute_packed_stride(
|
||||
StrideB{}, make_shape(problem_size.n(), problem_size.k(), problem_size.batch()))),
|
||||
stride_C(cutlass::make_cute_packed_stride(
|
||||
StrideC{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))),
|
||||
stride_D(cutlass::make_cute_packed_stride(
|
||||
StrideD{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))),
|
||||
tensor_a(problem_size_orig.m() * problem_size_orig.k() * problem_size_orig.batch()),
|
||||
tensor_b(problem_size_orig.k() * problem_size_orig.n() * problem_size_orig.batch()),
|
||||
tensor_c(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()),
|
||||
tensor_d(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()),
|
||||
gather_indices(options.index_size),
|
||||
tensor_a_gathered(problem_size.m() * problem_size.k() * problem_size_orig.batch()),
|
||||
tensor_b_gathered(problem_size.k() * problem_size.n() * problem_size_orig.batch()),
|
||||
tensor_c_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()),
|
||||
tensor_d_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()),
|
||||
tensor_d_reference(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()),
|
||||
gemm_mode(problem_size.batch() > 1 ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm),
|
||||
gemm(),
|
||||
// When constructing arguments for gather/scatter gemm, we must pass stride arguments
|
||||
// made for the original (non-gathered) problem size, because they are used to access
|
||||
// tensors of the original shape. However we still use the reduced (gathered) problem
|
||||
// shape since it corresponds to the logical indexing in reduced size GEMM.
|
||||
arguments{
|
||||
gemm_mode,
|
||||
problem_shape,
|
||||
{
|
||||
tensor_a.get(),
|
||||
stride_A_orig,
|
||||
tensor_b.get(),
|
||||
stride_B_orig
|
||||
},
|
||||
{
|
||||
{ alpha, beta },
|
||||
tensor_c.get(), stride_C_orig,
|
||||
tensor_d.get(), stride_D_orig,
|
||||
typename Epilogue::GatherC {gather_indices.get()},
|
||||
typename Epilogue::ScatterD{gather_indices.get()}
|
||||
},
|
||||
hw_info,
|
||||
typename Kernel::GatherA{gather_indices.get()},
|
||||
typename Kernel::GatherB{gather_indices.get()}
|
||||
},
|
||||
workspace(Gemm::get_workspace_size(arguments)),
|
||||
gemm_ref(),
|
||||
arguments_ref{
|
||||
gemm_mode,
|
||||
problem_shape,
|
||||
{
|
||||
DoGatherA ? tensor_a_gathered.get() : tensor_a.get(),
|
||||
stride_A,
|
||||
DoGatherB ? tensor_b_gathered.get() : tensor_b.get(),
|
||||
stride_B
|
||||
},
|
||||
{
|
||||
{ alpha, beta },
|
||||
DoGatherC ? tensor_c_gathered.get() : tensor_c.get(),
|
||||
stride_C,
|
||||
DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(),
|
||||
stride_D
|
||||
},
|
||||
hw_info
|
||||
},
|
||||
workspace_ref(GemmRef::get_workspace_size(arguments_ref)),
|
||||
gemm_opt(),
|
||||
arguments_opt{
|
||||
gemm_mode,
|
||||
problem_shape,
|
||||
{
|
||||
DoGatherA ? tensor_a_gathered.get() : tensor_a.get(),
|
||||
stride_A,
|
||||
DoGatherB ? tensor_b_gathered.get() : tensor_b.get(),
|
||||
stride_B
|
||||
},
|
||||
{
|
||||
{ alpha, beta },
|
||||
DoGatherC ? tensor_c_gathered.get() : tensor_c.get(),
|
||||
stride_C,
|
||||
DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(),
|
||||
stride_D
|
||||
},
|
||||
hw_info
|
||||
},
|
||||
workspace_opt(GemmOpt::get_workspace_size(arguments_opt))
|
||||
{
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::device::BlockFillRandomUniform(tensor_a.get(), tensor_a.size(), 1, ElementA(7), ElementA(-8), 0);
|
||||
cutlass::reference::device::BlockFillRandomUniform(tensor_b.get(), tensor_b.size(), 1, ElementB(7), ElementB(-8), 0);
|
||||
cutlass::reference::device::BlockFillRandomUniform(tensor_c.get(), tensor_c.size(), 1, ElementC(7), ElementC(-8), 0);
|
||||
cutlass::reference::device::BlockFillSequential(tensor_d.get(), tensor_d.size(), ElementD(0), ElementD(0));
|
||||
|
||||
// <- Fill gather_indices with unique random integers in range [0,n)
|
||||
int index_range = GatherModeM ? problem_size_orig.m() : (GatherModeN ? problem_size_orig.n() : problem_size_orig.k());
|
||||
std::vector<int> indices(index_range);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
|
||||
std::random_device make_seed;
|
||||
std::mt19937 source_of_randomness(make_seed());
|
||||
std::shuffle(indices.begin(), indices.end(), source_of_randomness);
|
||||
}
|
||||
gather_indices.copy_from_host(indices.data());
|
||||
|
||||
auto const gemm_init = [](auto & gemm, auto const & arguments, auto & workspace)
|
||||
{
|
||||
cutlass::Status status = gemm.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
status = gemm.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
};
|
||||
|
||||
gemm_init(gemm, arguments, workspace );
|
||||
gemm_init(gemm_ref, arguments_ref, workspace_ref);
|
||||
gemm_init(gemm_opt, arguments_opt, workspace_opt);
|
||||
}
|
||||
|
||||
void debug_output(std::ostream & os)
|
||||
{
|
||||
auto print_tensor = [](std::ostream &os, char const * name, auto const & data, auto shape, auto stride)
|
||||
{
|
||||
std::vector<remove_cvref_t<decltype(*data.get())>> h_data(data.size());
|
||||
data.copy_to_host(h_data.data());
|
||||
Tensor t = make_tensor(h_data.data(), shape, stride);
|
||||
os << "\n" << name << ": " << std::setw(4) << t << std::endl;
|
||||
};
|
||||
{
|
||||
auto [M,N,K,L] = problem_shape_orig;
|
||||
print_tensor(os, "A", tensor_a, make_shape(M,K,L), stride_A_orig);
|
||||
print_tensor(os, "B", tensor_b, make_shape(N,K,L), stride_B_orig);
|
||||
print_tensor(os, "C", tensor_c, make_shape(M,N,L), stride_C_orig);
|
||||
print_tensor(os, "D", tensor_d, make_shape(M,N,L), stride_D_orig);
|
||||
print_tensor(os, "D reference", tensor_d_reference, make_shape(M,N,L), stride_D_orig);
|
||||
print_tensor(os, "indices", gather_indices, make_shape(gather_indices.size()), make_stride(_1{}));
|
||||
}
|
||||
}
|
||||
|
||||
template<class Gemm2>
|
||||
static void run_gemm(Gemm2 &gemm)
|
||||
{
|
||||
cutlass::Status status = gemm.run();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template<class Gemm2>
|
||||
void run_reference(Gemm2 &gemm)
|
||||
{
|
||||
// Convenience wrapper around calls to separate gather/scatter kernels
|
||||
auto run_gather = [this](auto call, auto const & input, auto & output, auto gather_func, auto batch_size, auto stride)
|
||||
{
|
||||
[[maybe_unused]] auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; });
|
||||
constexpr int I = decltype(idx)::value;
|
||||
call(input.get(),
|
||||
output.get(),
|
||||
gather_func,
|
||||
batch_size,
|
||||
static_cast<int>(input.size() / batch_size),
|
||||
static_cast<int>(output.size() / batch_size),
|
||||
static_cast<int>(get<I>(stride)),
|
||||
hw_info);
|
||||
};
|
||||
|
||||
// Forward calls via lambda to avoid specifying template arguments
|
||||
auto gather_call = [](auto&&... args){ gather(static_cast<decltype(args)&&>(args)...); };
|
||||
// MSVC doesn't count use inside a false "if constexpr" branch.
|
||||
[[maybe_unused]] auto scatter_call = [](auto&&... args){ scatter(static_cast<decltype(args)&&>(args)...); };
|
||||
|
||||
if constexpr (DoGatherA) {
|
||||
run_gather(gather_call, tensor_a, tensor_a_gathered, arguments.gather_A, problem_size.batch(), stride_A);
|
||||
}
|
||||
if constexpr (DoGatherB) {
|
||||
run_gather(gather_call, tensor_b, tensor_b_gathered, arguments.gather_B, problem_size.batch(), stride_B);
|
||||
}
|
||||
if constexpr (DoGatherC) {
|
||||
if (beta != ElementComputeEpilogue(0)) {
|
||||
run_gather(gather_call, tensor_c, tensor_c_gathered, arguments.epilogue.gather_C, problem_size.batch(), stride_C);
|
||||
}
|
||||
}
|
||||
|
||||
run_gemm(gemm);
|
||||
|
||||
if constexpr (DoScatterD) {
|
||||
run_gather(scatter_call, tensor_d_gathered, tensor_d_reference, arguments.epilogue.scatter_D, problem_size.batch(), stride_D);
|
||||
}
|
||||
}
|
||||
|
||||
bool verify()
|
||||
{
|
||||
run_gemm(gemm);
|
||||
run_reference(gemm_ref);
|
||||
cudaDeviceSynchronize();
|
||||
return cutlass::reference::device::BlockCompareEqual(tensor_d.get(), tensor_d_reference.get(), tensor_d.size());
|
||||
}
|
||||
|
||||
bool run(Options const &options)
|
||||
{
|
||||
if (options.reference_check) {
|
||||
if (!verify()) {
|
||||
std::cout << "Failed validation" << std::endl;
|
||||
#if 1
|
||||
debug_output(std::cout);
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
std::cout << "Passed validation" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Run profiling loop
|
||||
//
|
||||
|
||||
auto const benchmark = [&](auto name, auto func)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
func();
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
double runtime = timer.elapsed_millis() / double(options.iterations);
|
||||
double gflops = 2 * double(problem_size.product()) / 1e6 / runtime; // Two flops per multiply-add
|
||||
|
||||
std::cout << name << ":\n";
|
||||
std::cout << " Runtime: " << runtime << " ms\n";
|
||||
std::cout << " GFLOPs: " << gflops << "\n";
|
||||
};
|
||||
|
||||
benchmark("Fused", [&](){ run_gemm(gemm); });
|
||||
benchmark("Unfused default", [&](){ run_reference(gemm_ref); });
|
||||
benchmark("Unfused optimized", [&](){ run_reference(gemm_opt); });
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace example
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// CUDA 12 minimum required
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA Toolkit version 12 or later.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (props.major < 9) {
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
example::Options options;
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << "\n";
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid arguments." << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
bool result = true;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
switch (options.mode) {
|
||||
using namespace example;
|
||||
case 0: {
|
||||
std::cout << "Gather A,C + scatter D on M mode:" << std::endl;
|
||||
using Runner = ExampleRunner<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // A
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, NoGather, // B
|
||||
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // C
|
||||
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // D
|
||||
float, float>;
|
||||
result &= Runner(options, hw_info).run(options);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
std::cout << "Gather B,C + scatter D on N mode:" << std::endl;
|
||||
using Runner = ExampleRunner<
|
||||
cutlass::half_t, cutlass::layout::RowMajor, NoGather, // A
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // B
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // C
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // D
|
||||
float, float>;
|
||||
result &= Runner(options, hw_info).run(options);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
std::cout << "Gather A,B on K mode:" << std::endl;
|
||||
using Runner = ExampleRunner<
|
||||
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // A
|
||||
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // B
|
||||
cutlass::half_t, cutlass::layout::RowMajor, NoGather, // C
|
||||
cutlass::half_t, cutlass::layout::RowMajor, NoGather, // D
|
||||
float, float>;
|
||||
result &= Runner(options, hw_info).run(options);
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return result ? EXIT_SUCCESS : EXIT_FAILURE;
|
||||
}
|
||||
32
examples/52_hopper_gather_scatter_fusion/CMakeLists.txt
Normal file
32
examples/52_hopper_gather_scatter_fusion/CMakeLists.txt
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
52_hopper_gather_scatter_fusion
|
||||
52_hopper_gather_scatter_fusion.cu
|
||||
)
|
||||
266
examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp
Normal file
266
examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp
Normal file
@ -0,0 +1,266 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "gather_tensor.hpp"
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class ProblemShape_,
|
||||
class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_,
|
||||
class GatherA_,
|
||||
class GatherB_,
|
||||
class TileScheduler_ = void
|
||||
>
|
||||
class GemmGather
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler = TileScheduler_;
|
||||
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
static_assert(std::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
|
||||
"Mainloop and epilogue do not agree on accumulator value type.");
|
||||
|
||||
using GatherA = GatherA_;
|
||||
using GatherB = GatherB_;
|
||||
|
||||
static constexpr int SharedStorageSize = static_cast<int>(cute::max(
|
||||
sizeof(typename CollectiveMainloop::SharedStorage),
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage)));
|
||||
|
||||
static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{});
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
GatherA gather_A{};
|
||||
GatherB gather_B{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode;
|
||||
ProblemShape problem_shape;
|
||||
MainloopParams mainloop;
|
||||
EpilogueParams epilogue;
|
||||
GatherA gather_A{};
|
||||
GatherB gather_B{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments.
|
||||
static
|
||||
Params
|
||||
to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
return {
|
||||
args.mode,
|
||||
args.problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
|
||||
args.gather_A,
|
||||
args.gather_B
|
||||
};
|
||||
}
|
||||
|
||||
static
|
||||
Status
|
||||
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static
|
||||
bool
|
||||
can_implement(Arguments const& args) {
|
||||
return args.mode == GemmUniversalMode::kGemm or
|
||||
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
|
||||
}
|
||||
|
||||
static
|
||||
int
|
||||
get_workspace_size(Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
int batch_count = 1;
|
||||
if constexpr (rank(ProblemShape{}) == 4) {
|
||||
batch_count = cute::size<3>(params.problem_shape);
|
||||
}
|
||||
|
||||
return dim3(
|
||||
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
|
||||
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
|
||||
batch_count
|
||||
);
|
||||
}
|
||||
|
||||
static constexpr
|
||||
dim3
|
||||
get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
operator()(Params const& params, char* smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Preconditions
|
||||
CUTE_STATIC_ASSERT(is_static<TileShape>::value);
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
auto M = get<0>(problem_shape_MNKL);
|
||||
auto N = get<1>(problem_shape_MNKL);
|
||||
auto K = get<2>(problem_shape_MNKL);
|
||||
auto L = get<3>(problem_shape_MNKL);
|
||||
|
||||
// Preconditions
|
||||
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
int thread_idx = int(threadIdx.x);
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
auto [m_coord, n_coord, l_coord] = blockIdx;
|
||||
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l)
|
||||
|
||||
// Represent the full tensors
|
||||
Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l)
|
||||
Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l)
|
||||
|
||||
// Get batch slice
|
||||
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
|
||||
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)
|
||||
|
||||
// Slice to get the tiles this thread block is responsible for
|
||||
Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
|
||||
|
||||
// Compute tile residues for predication
|
||||
auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
|
||||
auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
|
||||
auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
|
||||
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
|
||||
|
||||
// Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
|
||||
TiledMma tiled_mma;
|
||||
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
clear(accumulators);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
|
||||
int k_tile_count = size<2>(gA);
|
||||
|
||||
// Perform the collective scoped MMA
|
||||
CollectiveMainloop collective_mma;
|
||||
collective_mma(
|
||||
accumulators,
|
||||
gA,
|
||||
gB,
|
||||
accumulators,
|
||||
k_tile_iter, k_tile_count,
|
||||
residue_mnk,
|
||||
thread_idx,
|
||||
smem_buf
|
||||
);
|
||||
|
||||
// Epilogue and write to gD
|
||||
CollectiveEpilogue epilogue{params.epilogue};
|
||||
epilogue(
|
||||
problem_shape_MNKL,
|
||||
blk_shape,
|
||||
blk_coord_mnkl,
|
||||
accumulators,
|
||||
tiled_mma,
|
||||
residue_mnk,
|
||||
thread_idx,
|
||||
smem_buf
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user