CUTLASS 2.1 (#83)
CUTLASS 2.1 contributes:
- BLAS-style host-side API added to CUTLASS Library
- Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
- Minor enhancements and bug fixes
CHANGELOG.md (14 changes)
@@ -1,6 +1,14 @@
# NVIDIA CUTLASS Changelog

-# CUTLASS 2.0
+# CUTLASS 2.x
+
+## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
+ * BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
+ * API to launch compiled kernel instances for GEMM and planar complex GEMM
+ * Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
+ * Computes complex matrix products on matrices stored as disjoint real and imaginary parts
+ * [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu)
+ * Minor enhancements and bug fixes
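For reference, a sketch of the arithmetic the new planar complex kernels perform: this is the standard expansion of complex multiplication over separately stored real and imaginary planes, stated here for context rather than quoted from the commit.

```latex
% D = alpha * op(A) * op(B) + beta * C, with each operand split into
% real and imaginary planes: A = A_re + i A_im, B = B_re + i B_im.
% The complex product expands into four real-valued GEMMs:
D_{re} = A_{re} B_{re} - A_{im} B_{im}
\qquad
D_{im} = A_{re} B_{im} + A_{im} B_{re}
% Conjugating an operand (the 'c'/'h' transforms) negates its imaginary plane.
```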

## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
* Substantially refactored for
@@ -22,7 +30,7 @@
* Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
* Batched GEMM operations
* Complex-valued GEMMs
-* Note: a host compiler supporting C++11 or greater is required.
+* **Note: a host compiler supporting C++11 or greater is required.**

# CUTLASS 1.x
@@ -76,7 +84,7 @@

## Copyright

-Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.

```
Redistribution and use in source and binary forms, with or without modification, are permitted
CMakeLists.txt (235 changes; Normal file → Executable file)
@@ -32,15 +32,14 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")

-project(CUTLASS VERSION 2.0.0 LANGUAGES CXX)
+project(CUTLASS VERSION 2.1.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

find_package(Doxygen QUIET)

#
-# CUTLASS 2.0 requires C++11
+# CUTLASS 2.x requires C++11
#

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
@@ -49,7 +48,7 @@ if(CUTLASS_NATIVE_CUDA)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
else()
-string(APPEND NVCC_FLAGS " --std=c++11")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11)
endif()

if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
@@ -58,13 +57,26 @@ endif()

message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")

-if(${CMAKE_PROJECT_NAME} MATCHES ${PROJECT_NAME})
-set(_CUTLASS_ENABLE_TESTS ON)
+set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
+
+if(CUTLASS_ENABLE_HEADERS_ONLY)
+set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
+set(CUTLASS_ENABLE_TOOLS_INIT OFF)
else()
-set(_CUTLASS_ENABLE_TESTS OFF)
+set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
+set(CUTLASS_ENABLE_TOOLS_INIT ON)
endif()

-set(CUTLASS_ENABLE_TESTS ${_CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS Tests")
+set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
+set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
+
+if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
+set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
+else()
+set(CUTLASS_ENABLE_TESTS_INIT OFF)
+endif()
+
+set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")

if (CUTLASS_ENABLE_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
@@ -86,31 +98,22 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()

-if(CUDA_COMPILER MATCHES "[Cc]lang")
-if(NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
-message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
-endif()
-if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
-message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
-endif()
-endif()

set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
+set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")

# Special policy introduced in CMake 3.13
if (POLICY CMP0076)
cmake_policy(SET CMP0076 NEW)
endif()
-endif()

# check if the configuration is supported
-if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
+if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
endif()

include(GNUInstallDirs)

link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

###################################################################################################
#
# Configure CMake variables
@@ -120,11 +123,14 @@ include(GNUInstallDirs)
message(STATUS "CUDA Compilation Architectures: ${CUTLASS_NVCC_ARCHS_ENABLED}")

if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
-# By default we want to build in Release mode to ensure that we're getting best performance.
+# By default we want to build in Release mode to ensure that we're getting best performance.
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release")
endif()

set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries")

if(WIN32)
# On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
@@ -132,29 +138,35 @@ endif()

if (WIN32)
# Enable more warnings and treat as errors
-string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)

# Disable warning on Unicode characters
-string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)

# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
-string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/fp:strict)
endif(WIN32)

if (${CUTLASS_NVCC_VERBOSE})
-string(APPEND NVCC_FLAGS " -v")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
endif()

set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")
-set(CUTLASS_LIBRARY_KERNELS "128x128" CACHE STRING "Comma delimited list of kernel name filters. Default '' means all kernels are enabled.")

+#
+# CUTLASS generator cmake configuration
+#
+set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
+set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")

# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
-string(APPEND NVCC_FLAGS " -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
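The generator and test-level cache variables above are normally set on the configure command line, in the same style the example documentation in this commit uses. Illustrative values, not taken from the commit:

$ cmake .. -DCUTLASS_LIBRARY_OPERATIONS=all -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_TEST_LEVEL=1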
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
@@ -166,7 +178,7 @@ else()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
endif()

-set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
+set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.")

#
@@ -182,7 +194,7 @@ set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CAC
# ...
#
if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer
-string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS --compiler-options=-fsanitize=address --compiler-options=-fno-omit-frame-pointer)
string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address")
endif()

@@ -192,85 +204,127 @@ endif()
#
###################################################################################################

-foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
-if(CUTLASS_NVCC_EMBED_CUBIN)
-string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=sm_${ARCH}")
-endif()
-if(CUTLASS_NVCC_EMBED_PTX)
-string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=compute_${ARCH}")
-endif()
-string(APPEND CLANG_FLAGS " --cuda-gpu-arch=sm_${ARCH}")
-endforeach()
-
if(CUTLASS_NVCC_EMBED_PTX)
-string(APPEND CLANG_FLAGS " --cuda-include-ptx=all")
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all)
endif()

if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
-string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
+list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
endif()

if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
-string(APPEND NVCC_FLAGS " --keep") # --keep-dir may not work with nvcc for some directories.
-string(APPEND CLANG_FLAGS " -save-temps=${CUTLASS_NVCC_KEEP_DIR}")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
endif()

-if (CUTLASS_ENABLE_F16C)
-string(APPEND COMPILER_FLAGS " -DCUTLASS_ENABLE_F16C=1")
+if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
+list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1)
if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
-string(APPEND NVCC_FLAGS " -Xcompiler -mf16c")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c)
elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC"))
-string(APPEND NVCC_FLAGS " -Xcompiler /arch:AVX2")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2)
endif()
endif()

-string(APPEND NVCC_FLAGS " -lineinfo")
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)

-string(APPEND CLANG_FLAGS " -gmlt")
-
-if (UNIX)
-string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
-string(APPEND NVCC_FLAGS " -Xcompiler -fno-strict-aliasing")
+# Don't leak lineinfo in release builds
+if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)
+list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo)
endif()

if(CUDA_COMPILER MATCHES "[Cc]lang")
-string(APPEND CLANG_FLAGS " --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}")
-string(APPEND CLANG_FLAGS " -mllvm -pragma-unroll-threshold=100000")
-string(APPEND CLANG_FLAGS " -mllvm -unroll-threshold=5000")
-string(APPEND CLANG_FLAGS " -Wno-unused-command-line-argument")
+if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
+message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
+endif()
+
+if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
+message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
+endif()
+
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000)
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000)
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)

string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR)
list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR)
-string(APPEND CLANG_FLAGS " -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}")
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR})

# needed for libcublasLt.so in case it's installed in the same location as libcudart.so
# dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags)
# Otherwise linker uses RUNPATH and that does not propagate to loaded libs.
-string(APPEND CLANG_FLAGS " -Wl,--disable-new-dtags")
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags)

link_libraries(nvidia::cudart)
endif()

-if(CUDA_COMPILER MATCHES "[Cc]lang")
-string(APPEND CMAKE_CXX_FLAGS "${COMMON_FLAGS} ${CLANG_FLAGS}")
-string(APPEND CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${CLANG_FLAGS_RELEASE}")
-string(APPEND CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${CLANG_FLAGS_RELWITHDEBINFO}")
-string(APPEND CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${CLANG_FLAGS_DEBUG}")
-elseif (CUTLASS_NATIVE_CUDA)
-string(APPEND CMAKE_CUDA_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}")
-string(APPEND CMAKE_CUDA_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}")
-string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}")
-string(APPEND CMAKE_CUDA_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}")
-else()
-string(APPEND CUDA_NVCC_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}")
-string(APPEND CUDA_NVCC_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}")
-string(APPEND CUDA_NVCC_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}")
-string(APPEND CUDA_NVCC_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}")
-endif()
+function(cutlass_apply_cuda_gencode_flags TARGET)
+
+set(NVCC_FLAGS)
+set(CLANG_FLAGS)
+foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
+list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
+set(CODES)
+if(CUTLASS_NVCC_EMBED_CUBIN)
+list(APPEND CODES sm_${ARCH})
+endif()
+if(CUTLASS_NVCC_EMBED_PTX)
+list(APPEND CODES compute_${ARCH})
+endif()
+list(JOIN CODES "," CODES_STR)
+list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
+endforeach()
+
+if (CUDA_COMPILER MATCHES "[Cc]lang")
+target_compile_options(
+${TARGET}
+PRIVATE
+$<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
+)
+else()
+target_compile_options(
+${TARGET}
+PRIVATE
+$<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>
+)
+endif()
+
+endfunction()
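For a sense of what the new per-target function produces: with CUTLASS_NVCC_ARCHS_ENABLED set to "70;75" and both embedding options ON, the nvcc branch is equivalent to the following sketch (the target and source file are hypothetical, introduced only to illustrate the expansion):

```cmake
# Hypothetical target for illustration only.
add_library(example_kernels OBJECT kernels.cu)

# Equivalent effect of cutlass_apply_cuda_gencode_flags(example_kernels):
# each architecture contributes one -gencode flag carrying both a cubin
# (sm_XX) entry from CUTLASS_NVCC_EMBED_CUBIN and a PTX (compute_XX) entry
# from CUTLASS_NVCC_EMBED_PTX.
target_compile_options(example_kernels PRIVATE
  $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_70,code=[sm_70,compute_70]>
  $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_75,code=[sm_75,compute_75]>
)
```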
function(cutlass_apply_standard_compile_options TARGET)

if(CUDA_COMPILER MATCHES "[Cc]lang")
set(CUDA_COMPILE_LANGUAGE CXX)
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS})
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG})
else()
set(CUDA_COMPILE_LANGUAGE CUDA)
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS})
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE})
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO})
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
endif()

target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:${_FLAGS}>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELEASE>:${_FLAGS_RELEASE}>>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELWITHDEBINFO>:${_FLAGS_RELWITHDEBINFO}>>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:DEBUG>:${_FLAGS_DEBUG}>>
)

endfunction()

#
# The following items should eventually be pushed into cutlass/CMakeLists.txt
@@ -324,8 +378,8 @@ if (NOT DEFINED CUTLASS_REVISION)
endif()

configure_file(
-${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
-${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
+${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
+${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
@ONLY)

target_include_directories(
@@ -338,8 +392,8 @@ target_include_directories(
)

install(
-DIRECTORY
-${CUTLASS_INCLUDE_DIR}/
+DIRECTORY
+${CUTLASS_INCLUDE_DIR}/
+${CMAKE_CURRENT_BINARY_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
@@ -399,27 +453,6 @@ endif()

################################################################################

-set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
-
-if(CUTLASS_ENABLE_HEADERS_ONLY)
-set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
-set(CUTLASS_ENABLE_TOOLS_INIT OFF)
-else()
-set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
-set(CUTLASS_ENABLE_TOOLS_INIT ON)
-endif()
-
-set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
-set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
-
-if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
-set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
-else()
-set(CUTLASS_ENABLE_TESTS_INIT OFF)
-endif()
-
-set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
-
if(CUTLASS_ENABLE_TOOLS)
add_subdirectory(tools)
endif()
CUDA.cmake (159 changes)
@@ -39,23 +39,27 @@ if(CUTLASS_NATIVE_CUDA)
enable_language(CUDA)

-if(NOT CUDA_VERSION)
-set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
-endif()
-if(NOT CUDA_TOOLKIT_ROOT_DIR)
-get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
-endif()
-
else()

find_package(CUDA REQUIRED)
+# We workaround missing variables with the native flow by also finding the CUDA toolkit the old way.

endif()
+if(NOT CMAKE_CUDA_COMPILER_VERSION)
+set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION})
+endif()

+if(NOT CUDA_VERSION)
+set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
+endif()
+if(NOT CUDA_TOOLKIT_ROOT_DIR)
+get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
+endif()

if (CUDA_VERSION VERSION_LESS 9.2)
message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.")
endif()

if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang")
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
@@ -74,7 +78,7 @@ find_library(
# in the CUDA toolkit we're building against.
)

-if(CUDART_LIBRARY)
+if(NOT TARGET cudart AND CUDART_LIBRARY)

message(STATUS "CUDART: ${CUDART_LIBRARY}")

@@ -95,6 +99,10 @@ if(CUDART_LIBRARY)
${CUDART_LIBRARY}
)

+elseif(TARGET cudart)
+
+message(STATUS "CUDART: Already Found")
+
else()

message(STATUS "CUDART: Not Found")
@@ -116,7 +124,7 @@ find_library(
# in the CUDA toolkit we're building against.
)

-if(CUDA_DRIVER_LIBRARY)
+if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)

message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}")

@@ -137,6 +145,10 @@ if(CUDA_DRIVER_LIBRARY)
${CUDA_DRIVER_LIBRARY}
)

+elseif(TARGET cuda_driver)
+
+message(STATUS "CUDA Driver: Already Found")
+
else()

message(STATUS "CUDA Driver: Not Found")
@@ -156,7 +168,7 @@ find_library(
# in the CUDA toolkit we're building against.
)

-if(NVRTC_LIBRARY)
+if(NOT TARGET nvrtc AND NVRTC_LIBRARY)

message(STATUS "NVRTC: ${NVRTC_LIBRARY}")

@@ -177,6 +189,10 @@ if(NVRTC_LIBRARY)
${NVRTC_LIBRARY}
)

+elseif(TARGET nvrtc)
+
+message(STATUS "NVRTC: Already Found")
+
else()

message(STATUS "NVRTC: Not Found")
@@ -197,48 +213,137 @@ function(cutlass_correct_source_file_language_property)
endif()
endfunction()

-function(cutlass_add_library)
+set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation")
+set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")

-set(options INTERFACE STATIC SHARED OBJECT)
-set(oneValueArgs)
+function(cutlass_unify_source_files TARGET_ARGS_VAR)
+
+set(options)
+set(oneValueArgs BATCH_SOURCES BATCH_SIZE)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

-if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang" OR __INTERFACE)
-cutlass_correct_source_file_language_property(${ARGN})
-add_library(${ARGN})
-else()
-set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
-cuda_add_library(${ARGN})
+if (NOT DEFINED TARGET_ARGS_VAR)
+message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()

+if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
+set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
+endif()
+
+if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
+
+set(CUDA_FILE_ARGS)
+set(TARGET_SOURCE_ARGS)
+
+foreach(ARG ${__UNPARSED_ARGUMENTS})
+if(${ARG} MATCHES ".*\.cu$")
+list(APPEND CUDA_FILE_ARGS ${ARG})
+else()
+list(APPEND TARGET_SOURCE_ARGS ${ARG})
+endif()
+endforeach()
+
+list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
+while(NUM_CUDA_FILE_ARGS GREATER 0)
+list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
+string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}")
+string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH)
+set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu)
+message(STATUS "Generating ${BATCH_FILE}")
+file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n")
+foreach(CUDA_FILE ${CUDA_FILE_BATCH})
+get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE)
+file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n")
+endforeach()
+list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE})
+if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE)
+break()
+endif()
+list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS)
+list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
+endwhile()
+
+else()
+
+set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
+
+endif()
+
+set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)
+
+endfunction()
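A sketch of how the new unity-build helper is meant to be called (the output variable, target, and source names here are hypothetical, chosen only to illustrate the function above):

```cmake
# With CUTLASS_UNITY_BUILD_ENABLED=ON (the default) and a batch size of 2,
# the four .cu files below are folded into two generated *.unity.*.cu files
# that simply #include the originals; non-.cu sources pass through unchanged.
# UNIFIED_SOURCES receives the rewritten source list.
cutlass_unify_source_files(UNIFIED_SOURCES
  BATCH_SIZE 2
  gemm_a.cu gemm_b.cu gemm_c.cu gemm_d.cu helpers.cpp
)
add_library(my_kernels ${UNIFIED_SOURCES})
```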

-function(cutlass_add_executable)
+function(cutlass_add_library NAME)

set(options)
-set(oneValueArgs)
+set(oneValueArgs EXPORT_NAME)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

+cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
+
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
-cutlass_correct_source_file_language_property(${ARGN})
-add_executable(${ARGN})
+cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
+add_library(${NAME} ${TARGET_SOURCE_ARGS})
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
-cuda_add_executable(${ARGN})
+cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
endif()

+cutlass_apply_standard_compile_options(${NAME})
+cutlass_apply_cuda_gencode_flags(${NAME})
+
target_compile_features(
${NAME}
INTERFACE
cxx_std_11
)

+if(__EXPORT_NAME)
+add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
+set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
+endif()
+
endfunction()

-function(cutlass_target_sources)
+function(cutlass_add_executable NAME)

set(options)
set(oneValueArgs)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

-cutlass_correct_source_file_language_property(${ARGN})
-target_sources(${ARGN})
+cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
+
+if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
+cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
+add_executable(${NAME} ${TARGET_SOURCE_ARGS})
+else()
+set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
+cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS})
+endif()
+
+cutlass_apply_standard_compile_options(${NAME})
+cutlass_apply_cuda_gencode_flags(${NAME})
+
+target_compile_features(
+${NAME}
+INTERFACE
+cxx_std_11
+)
+
+endfunction()
+
+function(cutlass_target_sources NAME)
+
+set(options)
+set(oneValueArgs)
+set(multiValueArgs)
+cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
+cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
+target_sources(${NAME} ${TARGET_SOURCE_ARGS})
+
+endfunction()
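Putting the helpers together, a sketch of how targets might be declared with the functions above (target and source names are hypothetical; the EXPORT_NAME keyword is taken from the new cutlass_add_library signature):

```cmake
# Declares a library from CUDA sources: sources are optionally unified,
# standard CUTLASS compile options and -gencode flags are applied, and an
# alias target nvidia::cutlass::ops is created for consumers.
cutlass_add_library(cutlass_ops EXPORT_NAME ops gemm_ops.cu)

# An executable gets the same treatment, minus the export alias.
cutlass_add_executable(ops_test ops_test.cu)
target_link_libraries(ops_test PRIVATE cutlass_ops)
```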
README.md (27 changes)
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

-# CUTLASS 2.0
+# CUTLASS 2.1

-_CUTLASS 2.0 - November 2019_
+_CUTLASS 2.1 - April 2020_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -25,6 +25,14 @@ NVIDIA's Volta and Turing architectures.

See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

+# What's New in CUTLASS 2.1
+
+CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
+
+- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
+- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
+
+
# What's New in CUTLASS 2.0

CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
@@ -33,7 +41,7 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to
- Robust and durable templates that reliably span the design space
- Encapsulated functionality that may be reusable in other contexts

-See the [CHANGELOG](CHANGELOG.md) for more details.
+**See the [CHANGELOG](CHANGELOG.md) for more details.**

See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
@@ -82,7 +90,7 @@ any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU.

# Documentation

-CUTLASS 2.0 is described in the following documents and the accompanying
+CUTLASS 2.1 is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).

- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
@@ -177,7 +185,7 @@ include/ # client applications should target this directory

### CUTLASS SDK Examples

-CUTLASS SDK examples apply CUTLASS templates to implement basic computations.
+[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.

```
examples/
@@ -198,12 +206,19 @@ examples/
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores

08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores

+10_planar_complex/ # example demonstrating planar complex GEMM kernels
+
+11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
```

### Tools
```
tools/
library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates
include/
cutlass/
library/

profiler/ # CUTLASS Profiler - command-line utility for executing operations in the
# CUTLASS Library
@@ -279,7 +294,7 @@ The official list of CUTLASS developers and contributors is available here: [CON

# Copyright

-Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.

```
Redistribution and use in source and binary forms, with or without modification, are permitted
cuBLAS.cmake (22 changes)
@@ -1,7 +1,8 @@

message(STATUS "Configuring cublas ...")

-if(DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS)
+if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
+(DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED))

# Don't add cuBLAS if it's defined and false, assume it's not found.

@@ -59,11 +60,13 @@ endif()
if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)

if(WIN32)
-add_library(cublas STATIC IMPORTED)
+add_library(cublas STATIC IMPORTED GLOBAL)
else()
-add_library(cublas SHARED IMPORTED)
+add_library(cublas SHARED IMPORTED GLOBAL)
endif()

+add_library(nvidia::cublas ALIAS cublas)
+
set_property(
TARGET cublas
PROPERTY IMPORTED_LOCATION
@@ -88,23 +91,20 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
${CUBLAS_PATH}/lib/x64
/usr/lib/x86_64-linux-gnu)

-if(_CUBLASLT_LIBRARY)
+if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)

if(WIN32)
-add_library(cublasLt STATIC IMPORTED)
+add_library(cublasLt STATIC IMPORTED GLOBAL)
else()
-add_library(cublasLt SHARED IMPORTED)
+add_library(cublasLt SHARED IMPORTED GLOBAL)
endif()

set_property(
TARGET cublasLt
PROPERTY IMPORTED_LOCATION
${_CUBLASLT_LIBRARY})

-target_link_libraries(
-cublas
-INTERFACE
-cublasLt)
-
+add_library(nvidia::cublasLt ALIAS cublasLt)

endif()
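With the GLOBAL imported targets and nvidia:: aliases in place, targets anywhere in the build tree can link cuBLAS by name. A minimal consumer sketch (the application target and source file are hypothetical):

```cmake
# Link an application against the imported cuBLAS targets defined above.
add_executable(gemm_bench gemm_bench.cu)
target_link_libraries(gemm_bench
  PRIVATE
  nvidia::cublas    # alias of the imported cublas target
  nvidia::cublasLt  # defined when libcublasLt was found
)
```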

@@ -173,14 +173,35 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
EpilogueOp>;

int main() {

+//
+// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
+//
+// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
+//
+if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
+std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
+return -1;
+}
+
cudaDeviceProp props;
-CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
+
+cudaError_t error = cudaGetDeviceProperties(&props, 0);
+if (error != cudaSuccess) {
+std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+return -1;
+}

if (!(props.major >= 7)) {
std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70."
<< std::endl;
-return 0;
+return -1;
}

//
// Define problem size
//

const int length_m = 5120;
const int length_n = 4096;
const int length_k = 4096;

@@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Volta GPU.

Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing
high performance kernels at scale which works for multiple problem sizes with good abstractions is
-really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
+really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU
easily.

@@ -189,13 +189,27 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
NumStages>;

int main() {

+// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
+//
+// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
+if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
+std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
+return -1;
+}
+
cudaDeviceProp props;
-CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
+
+cudaError_t error = cudaGetDeviceProperties(&props, 0);
+if (error != cudaSuccess) {
+std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+return -1;
+}

if (!(props.major >= 7)) {
std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70."
<< std::endl;
-return 0;
+return -1;
}

const int length_m = 5120;

@@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU.

Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing
high performance kernels at scale which works for multiple problem sizes with good abstractions is
-really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
+really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU
easily.

@@ -187,13 +187,28 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
NumStages>;

int main() {
-cudaDeviceProp props;
-CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
-
-if (!(props.major >= 7 && props.minor >= 5)) {
-std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
+// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
+// in CUDA 10.2.
+//
+// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
+if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
+std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
+return -1;
+}
+
+cudaDeviceProp props;
+
+cudaError_t error = cudaGetDeviceProperties(&props, 0);
+if (error != cudaSuccess) {
+std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+return -1;
+}
+
+if (!((props.major * 10 + props.minor) >= 75)) {
+std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
<< std::endl;
-return 0;
+return -1;
}

const int length_m = 5120;
examples/10_planar_complex/CMakeLists.txt (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


# Planar Complex GEMM example
cutlass_example_add_executable(
10_planar_complex
planar_complex.cu
)


#
# This example depends on the CUTLASS Library
#

target_link_libraries(
10_planar_complex
PRIVATE
cutlass_lib
cutlass_tools_util_includes
)
examples/10_planar_complex/planar_complex.cu (new file, 551 lines)
@@ -0,0 +1,551 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Planar Complex GEMM

This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels supporting
the batched strided mode.

These kernels represent complex matrices by storing the real and imaginary parts of the matrix in
disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts
as either column-major or row-major layouts with a single leading dimension indicating the stride
between columns or rows.

The CUTLASS Library collects multiple template instantiations in a data structure and offers
a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures.

CUTLASS decouples matrix layout from complex transformation, so four possible transformations
are possible on the A and B operands:

n: column-major
c: column-major complex conjugate
t: row-major
h: row-major complex conjugate

The CUTLASS Library contains many kernel instances specialized for architecture, data type, tile
size, and alignment. This can result in long compile times.

To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.

$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex

This builds all planar complex GEMM variants for Volta and Turing architectures.

To build strictly the kernels needed for this example, an even narrower filter string may be
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).

$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn

$ make 10_planar_complex

$ ./examples/10_planar_complex/10_planar_complex --m=2048 --n=1024 --k=512 --batch=10
*/

#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor_planar_complex.h"

#include "cutlass/util/reference/device/tensor_fill.h"

#include "cutlass/util/reference/device/gemm_planar_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"

#include "cutlass/library/handle.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result {

double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;

//
// Methods
//

Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

bool help;

cutlass::gemm::GemmCoord problem_size;
int batch_count;
cutlass::complex<float> alpha;
cutlass::complex<float> beta;

bool reference_check;
int iterations;

Options():
help(false),
problem_size({1024, 1024, 1024}),
batch_count(1),
reference_check(true),
iterations(20),
alpha(1),
beta() { }

bool valid() {
return true;
}

// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);

if (cmd.check_cmd_line_flag("help")) {
help = true;
}

cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("batch", batch_count);

cmd.get_cmd_line_argument("alpha", alpha.real());
cmd.get_cmd_line_argument("alpha_i", alpha.imag());
cmd.get_cmd_line_argument("beta", beta.real());
cmd.get_cmd_line_argument("beta_i", beta.imag());

cmd.get_cmd_line_argument("iterations", iterations);
}

/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "10_planar_complex example\n\n"
<< " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m <int> GEMM M dimension\n"
<< " --n <int> GEMM N dimension\n"
<< " --k <int> GEMM K dimension\n"
<< " --batch <int> Number of GEMM operations executed in one batch\n"
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
<< " --iterations <int> Number of profiling iterations to perform.\n\n";

out << "\n\nExamples:\n\n"
<< "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n";

return out;
}

/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {

// Number of real-valued multiply-adds
int64_t fmas = problem_size.product() * batch_count * 4;

// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
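The throughput computation in gflops() follows directly from the planar complex arithmetic: each complex multiply-accumulate expands into four real fused multiply-adds, and each FMA counts as two floating-point operations. Restating the code above in symbols:

```latex
\text{FMA}_{\text{real}} = 4 \cdot M N K \cdot \text{batch}
\qquad
\text{GFLOP/s} = \frac{2 \cdot \text{FMA}_{\text{real}}}{10^{9} \cdot t_{\text{seconds}}}
```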

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Performance test environment for planar complex
class TestbedPlanarComplex {
public:

using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::ColumnMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = cutlass::half_t;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementCompute = float;
using ElementAccumulator = float;

//
// Data members
//

cutlass::library::Handle handle;

cutlass::gemm::GemmCoord problem_size;
int batch_count;
cutlass::DeviceAllocation<ElementA> tensor_A;
cutlass::DeviceAllocation<ElementB> tensor_B;
cutlass::DeviceAllocation<ElementC> tensor_C;
cutlass::DeviceAllocation<ElementC> tensor_D;
cutlass::DeviceAllocation<ElementC> tensor_D_ref;

//
// Methods
//

TestbedPlanarComplex(
Options const &options
):
problem_size(options.problem_size), batch_count(options.batch_count) {

// Allocate device memory for batched strided GEMM
tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2);
tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2);
tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
}

void initialize() {

uint64_t seed = 1073;

// Use small integers to simplify correctness checking
int scope_max = 6;
int scope_min = -6;

cutlass::reference::device::BlockFillRandomUniform(
tensor_A.get(), tensor_A.size(), seed, ElementA(scope_max), ElementA(scope_min), 0);

cutlass::reference::device::BlockFillRandomUniform(
tensor_B.get(), tensor_B.size(), seed * 2019, ElementB(scope_max), ElementB(scope_min), 0);

cutlass::reference::device::BlockFillRandomUniform(
tensor_C.get(), tensor_C.size(), seed * 2020, ElementC(scope_max), ElementC(scope_min), 0);
}

Result profile(Options const &options) {

Result result;

initialize();

ElementA *ptr_A = tensor_A.get();
ElementB *ptr_B = tensor_B.get();
ElementC *ptr_C = tensor_C.get();
ElementC *ptr_D = tensor_D.get();

int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2;
int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2;
int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;

int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);

int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n();
int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n();
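The factor of 2 in the allocations and batch strides reflects the planar layout: for each batch element, the imaginary plane sits immediately after the real plane. Restating the addressing implied by the strides above (for operand A, an M-by-K matrix, batch element b):

```latex
\text{addr}\!\left(A^{(b)}_{re}\right) = \text{ptr\_A} + b \cdot 2MK
\qquad
\text{addr}\!\left(A^{(b)}_{im}\right) = \text{addr}\!\left(A^{(b)}_{re}\right) + MK
```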
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
    }

    // Record an event at the start of a series of GEMMs
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    //
    // Run profiling loop
    //

    for (int iter = 0; iter < options.iterations; ++iter) {

      //
      // Execute the planar complex GEMM kernel via the CUTLASS Library's
      // dispatch routines.
      //
      // Note, for planar complex GEMM kernels, all numeric type arguments
      // specify the data type of the base real types. These are understood to
      // apply to planar complex representations of matrices in memory and to complex<T>
      // structures for scalars.
      //
      // See tools/library/include/cutlass/library/handle.h for more details.
      //

      result.status = handle.gemm_planar_complex(
        problem_size.m(),                                 // GEMM M dimension
        problem_size.n(),                                 // GEMM N dimension
        problem_size.k(),                                 // GEMM K dimension

        cutlass::library::NumericTypeID::kF32,            // Base data type of complex-valued accumulation
        cutlass::library::NumericTypeID::kF32,            // Base data type of complex-valued alpha/beta scalars

        &options.alpha,                                   // Pointer to alpha scalar, of type complex<T>

        cutlass::library::NumericTypeID::kF16,            // Base data type of complex-valued A matrix
        cutlass::library::LayoutTypeID::kColumnMajor,     // Layout of A matrix
        cutlass::library::ComplexTransform::kConjugate,   // Complex transformation on A matrix operand
        ptr_A,                                            // Pointer to real part of A matrix
        ptr_A + imag_stride_A,                            // Pointer to imaginary part of A matrix
        lda,                                              // Leading dimension of real part of A matrix
        lda,                                              // Leading dimension of imaginary part of A matrix

        cutlass::library::NumericTypeID::kF16,            // Base data type of complex-valued B matrix
        cutlass::library::LayoutTypeID::kColumnMajor,     // Layout of B matrix
        cutlass::library::ComplexTransform::kNone,        // Complex transformation on B matrix operand
        ptr_B,                                            // Pointer to real part of B matrix
        ptr_B + imag_stride_B,                            // Pointer to imaginary part of B matrix
        ldb,                                              // Leading dimension of real part of B matrix
        ldb,                                              // Leading dimension of imaginary part of B matrix

        &options.beta,                                    // Pointer to beta scalar, of type complex<T>

        cutlass::library::NumericTypeID::kF16,            // Base data type of complex valued C and D matrices

        ptr_C,                                            // Pointer to real part of C matrix
        ptr_C + imag_stride_C,                            // Pointer to imaginary part of C matrix
        ldc,                                              // Leading dimension of real part of C matrix
        ldc,                                              // Leading dimension of imaginary part of C matrix

        ptr_D,                                            // Pointer to real part of D matrix
        ptr_D + imag_stride_D,                            // Pointer to imaginary part of D matrix
        ldd,                                              // Leading dimension of real part of D matrix
        ldd,                                              // Leading dimension of imaginary part of D matrix

        batch_count,                                      // Number of batched elements

        batch_stride_A,                                   // Stride between batches of real parts of A matrix
        batch_stride_A,                                   // Stride between batches of imaginary parts of A matrix

        batch_stride_B,                                   // Stride between batches of real parts of B matrix
        batch_stride_B,                                   // Stride between batches of imaginary parts of B matrix

        batch_stride_C,                                   // Stride between batches of real parts of C matrix
        batch_stride_C,                                   // Stride between batches of imaginary parts of C matrix

        batch_stride_D,                                   // Stride between batches of real parts of D matrix
        batch_stride_D                                    // Stride between batches of imaginary parts of D matrix
      );

      if (result.status != cutlass::Status::kSuccess) {
        std::cerr << "CUTLASS internal error - configuration not supported" << std::endl;
        return result;
      }
    }

    //
    // Stop profiling loop
    //

    // Record an event when the GEMMs are complete
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Compute average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    for (auto event : events) {
      (void)cudaEventDestroy(event);
    }

    if (handle.get_last_operation()) {
      std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl;
    }

    //
    // Compute reference in device code
    //

    if (options.reference_check) {

      result.passed = true;

      for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) {
        cutlass::reference::device::GemmPlanarComplex<
          ElementA, LayoutA,
          ElementB, LayoutB,
          ElementC, LayoutC,
          ElementAccumulator
        >(
          problem_size,
          options.alpha,
          {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A},
          cutlass::ComplexTransform::kConjugate,
          {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B},
          cutlass::ComplexTransform::kNone,
          options.beta,
          {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C},
          {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D}
        );

        ElementC epsilon = 0.1_hf;
        ElementC nonzero_floor = 0.1_hf;

        result.passed = cutlass::reference::device::BlockCompareRelativelyEqual(
          tensor_D.get() + idx * batch_stride_D,
          tensor_D_ref.get() + idx * batch_stride_D,
          batch_stride_D,
          epsilon,
          nonzero_floor
        );
      }

      if (result.passed) {
        std::cout << "Reference check passed." << std::endl;
      }
      else {
        std::cerr << "Error - reference check failed." << std::endl;
      }
    }

    std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
    std::cout << " GFLOPs: " << result.gflops << std::endl;

    return result;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  //
  // This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
  //
  // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit.
  //
  // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit.
  //

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major < 7) {
    std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70."
              << std::endl;
    return -1;
  }
  else if (props.major == 7 && props.minor <= 2) {
    //
    // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example.
    //
    if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
      std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
      return -1;
    }
  }
  else if (props.major == 7 && props.minor >= 5) {
    //
    // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example.
    //
    if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
      std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
      return -1;
    }
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  // Execute one problem size
  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

  TestbedPlanarComplex testbed(options);

  Result result = testbed.profile(options);

  return result.passed ? 0 : -1;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

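A note on the addressing used throughout this example: each batch stores its real plane immediately followed by its imaginary plane, which is why the imaginary pointers are formed as `ptr + imag_stride` and the batch strides cover two planes. Below is a minimal host-side sketch of that arithmetic (illustrative only; these helper names are not part of CUTLASS):

```
#include <cstdint>

// Offset of element (i, j) of the real plane of batch b in a column-major
// planar complex tensor; one plane holds plane_elements values.
inline int64_t real_offset(int64_t b, int64_t i, int64_t j,
                           int64_t ld, int64_t plane_elements) {
  return b * (2 * plane_elements) + j * ld + i;   // batch_stride = 2 * plane
}

// The matching imaginary element lives exactly one plane further along.
inline int64_t imag_offset(int64_t b, int64_t i, int64_t j,
                           int64_t ld, int64_t plane_elements) {
  return real_offset(b, i, j, ld, plane_elements) + plane_elements;
}
```
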
41
examples/11_planar_complex_array/CMakeLists.txt
Normal file
@ -0,0 +1,41 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright notice, this list of
#       conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice, this list of
#       conditions and the following disclaimer in the documentation and/or other materials
#       provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#       to endorse or promote products derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Planar Complex Array GEMM example
cutlass_example_add_executable(
  11_planar_complex_array
  planar_complex_array.cu
)


#
# This example depends on the CUTLASS Library
#

target_link_libraries(
  11_planar_complex_array
  PRIVATE
  cutlass_lib
  cutlass_tools_util_includes
)

611
examples/11_planar_complex_array/planar_complex_array.cu
Normal file
@ -0,0 +1,611 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Planar Complex Array Example

    This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels which
    execute a batch of matrix products, loading problem sizes and matrix base pointers from arrays
    in global memory.

    These kernels represent complex matrices by storing the real and imaginary parts of the matrix in
    disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts
    as either column-major or row-major layouts with a single leading dimension indicating the stride
    between columns or rows.

    The CUTLASS Library collects multiple template instantiations in a data structure and offers
    a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures.

    CUTLASS decouples matrix layout from complex transformation, so four transformations are
    possible on the A and B operands:

      n: column-major
      c: column-major complex conjugate
      t: row-major
      h: row-major complex conjugate

    To build strictly the planar complex kernels needed for general application, execute the following
    CMake command in an empty build directory.

      $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
        -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex

    This builds all planar complex GEMM variants for Volta and Turing architectures.

    To build strictly the kernels needed for this example, an even narrower filter string may be
    specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
    the 'CN' layout configuration (conjugate A operand with both A and B as column-major).

      $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
        -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn

      $ make 11_planar_complex_array

      $ ./examples/11_planar_complex_array/11_planar_complex_array --m=2048 --n=1024 --k=512 --batch=10
*/

#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor_planar_complex.h"

#include "cutlass/util/reference/device/tensor_fill.h"

#include "cutlass/util/reference/device/gemm_planar_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"

#include "cutlass/library/handle.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result {

  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  //
  // Methods
  //

  Result(
    double runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess
  ):
    runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;

  cutlass::gemm::GemmCoord problem_size;
  int batch_count;
  cutlass::complex<float> alpha;
  cutlass::complex<float> beta;

  bool reference_check;
  int iterations;

  Options():
    help(false),
    problem_size({1024, 1024, 1024}),
    batch_count(1),
    reference_check(true),
    iterations(20),
    alpha(1),
    beta() { }

  bool valid() {
    return true;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("m", problem_size.m());
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());
    cmd.get_cmd_line_argument("batch", batch_count);

    cmd.get_cmd_line_argument("alpha", alpha.real());
    cmd.get_cmd_line_argument("alpha_i", alpha.imag());
    cmd.get_cmd_line_argument("beta", beta.real());
    cmd.get_cmd_line_argument("beta_i", beta.imag());

    cmd.get_cmd_line_argument("iterations", iterations);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "11_planar_complex_array example\n\n"
      << "  This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --m <int>                   GEMM M dimension\n"
      << "  --n <int>                   GEMM N dimension\n"
      << "  --k <int>                   GEMM K dimension\n"
      << "  --batch <int>               Number of GEMM operations executed in one batch\n"
      << "  --alpha <f32>               Epilogue scalar alpha (real part)\n"
      << "  --alpha_i <f32>             Epilogue scalar alpha (imaginary part)\n"
      << "  --beta <f32>                Epilogue scalar beta (real part)\n"
      << "  --beta_i <f32>              Epilogue scalar beta (imaginary part)\n\n"
      << "  --iterations <int>          Number of profiling iterations to perform.\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds
    int64_t fmas = problem_size.product() * batch_count * 4;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};
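
For reference, the arithmetic behind `gflops()`: one complex multiply-add $(a+bi)(c+di)$ expands to four real multiply-adds, and each multiply-add counts as two floating-point operations, so

$$\text{GFLOP/s} = \frac{2 \cdot 4 \cdot m \cdot n \cdot k \cdot \text{batch}}{10^{9} \cdot t}$$

where $t$ is the runtime in seconds. With the default problem size (m = n = k = 1024, batch = 1), a measured runtime of one millisecond corresponds to roughly 8590 GFLOP/s.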

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Performance test environment for planar complex
class TestbedPlanarComplex {
public:

  // Half-precision input and output
  using Element = cutlass::half_t;

  // Configurations for layouts and internal computation
  using LayoutA = cutlass::layout::ColumnMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;
  using ElementCompute = float;
  using ElementAccumulator = float;

  //
  // Data members
  //

  cutlass::library::Handle handle;

  cutlass::gemm::GemmCoord problem_size;
  int batch_count;
  cutlass::DeviceAllocation<Element> tensor_A;
  cutlass::DeviceAllocation<Element> tensor_B;
  cutlass::DeviceAllocation<Element> tensor_C;
  cutlass::DeviceAllocation<Element> tensor_D;
  cutlass::DeviceAllocation<Element> tensor_D_ref;

  cutlass::DeviceAllocation<void *> ptr_A_real;
  cutlass::DeviceAllocation<void *> ptr_A_imag;
  cutlass::DeviceAllocation<void *> ptr_B_real;
  cutlass::DeviceAllocation<void *> ptr_B_imag;
  cutlass::DeviceAllocation<void *> ptr_C_real;
  cutlass::DeviceAllocation<void *> ptr_C_imag;
  cutlass::DeviceAllocation<void *> ptr_D_real;
  cutlass::DeviceAllocation<void *> ptr_D_imag;

  //
  // Methods
  //

  TestbedPlanarComplex(
    Options const &options
  ):
    problem_size(options.problem_size), batch_count(options.batch_count) {

    // Allocate device memory for batched planar complex GEMM
    tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2);
    tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2);
    tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
    tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
    tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);

    ptr_A_real.reset(batch_count);
    ptr_A_imag.reset(batch_count);
    ptr_B_real.reset(batch_count);
    ptr_B_imag.reset(batch_count);
    ptr_C_real.reset(batch_count);
    ptr_C_imag.reset(batch_count);
    ptr_D_real.reset(batch_count);
    ptr_D_imag.reset(batch_count);

  }

  void initialize() {

    uint64_t seed = 1073;

    // Use small integers to simplify correctness checking
    int scope_max = 6;
    int scope_min = -6;

    cutlass::reference::device::BlockFillRandomUniform(
      tensor_A.get(), tensor_A.size(), seed, Element(scope_max), Element(scope_min), 0);

    cutlass::reference::device::BlockFillRandomUniform(
      tensor_B.get(), tensor_B.size(), seed * 2019, Element(scope_max), Element(scope_min), 0);

    cutlass::reference::device::BlockFillRandomUniform(
      tensor_C.get(), tensor_C.size(), seed * 2020, Element(scope_max), Element(scope_min), 0);
  }

  Result profile(Options const &options) {

    Result result;

    initialize();

    Element *ptr_A = tensor_A.get();
    Element *ptr_B = tensor_B.get();
    Element *ptr_C = tensor_C.get();
    Element *ptr_D = tensor_D.get();

    int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2;
    int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2;
    int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
    int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;

    int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
    int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
    int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
    int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);

    int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
    int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
    int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n();
    int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n();

    //
    // Configure pointers in global memory
    //

    struct {
      Element *base;
      void **ptr_real;
      void **ptr_imag;
      int64_t batch_stride;
      int64_t imag_stride;
    } tensors[] = {
      { tensor_A.get(), ptr_A_real.get(), ptr_A_imag.get(), batch_stride_A, imag_stride_A},
      { tensor_B.get(), ptr_B_real.get(), ptr_B_imag.get(), batch_stride_B, imag_stride_B},
      { tensor_C.get(), ptr_C_real.get(), ptr_C_imag.get(), batch_stride_C, imag_stride_C},
      { tensor_D.get(), ptr_D_real.get(), ptr_D_imag.get(), batch_stride_D, imag_stride_D}
    };

    for (auto const &tensor : tensors) {
      for (int idx = 0; idx < batch_count; ++idx) {

        void *ptr_real = tensor.base + idx * tensor.batch_stride;
        void *ptr_imag = tensor.base + idx * tensor.batch_stride + tensor.imag_stride;

        cudaError_t error = cudaMemcpy(
          tensor.ptr_real + idx,
          &ptr_real,
          sizeof(void *),
          cudaMemcpyHostToDevice);

        if (error != cudaSuccess) {
          throw std::runtime_error("Failed to copy pointer to device memory");
        }

        error = cudaMemcpy(
          tensor.ptr_imag + idx,
          &ptr_imag,
          sizeof(void *),
          cudaMemcpyHostToDevice);

        if (error != cudaSuccess) {
          throw std::runtime_error("Failed to copy pointer to device memory");
        }
      }
    }

    //
    // Construct events
    //

    cudaEvent_t events[2];

    for (auto & event : events) {
      result.error = cudaEventCreate(&event);
      if (result.error != cudaSuccess) {
        std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
        return result;
      }
    }

    // Record an event at the start of a series of GEMM operations
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    //
    // Run profiling loop
    //

    for (int iter = 0; iter < options.iterations; ++iter) {

      //
      // Execute the planar complex array GEMM kernel via the CUTLASS Library's
      // dispatch routines.
      //
      // Note, for planar complex array GEMM kernels, all numeric type arguments
      // specify the data type of the base real types. These are understood to
      // apply to planar complex representations of matrices in memory and to complex<T>
      // structures for scalars.
      //
      // See tools/library/include/cutlass/library/handle.h for more details.
      //

      result.status = handle.gemm_planar_complex_array(

        problem_size.m(),                                 // expected GEMM M dimension
        problem_size.n(),                                 // expected GEMM N dimension
        problem_size.k(),                                 // expected GEMM K dimension
        batch_count,                                      // Number of batched elements

        nullptr,                                          // Optional device array of per-batch M dimensions (none here)
        nullptr,                                          // Optional device array of per-batch N dimensions (none here)
        nullptr,                                          // Optional device array of per-batch K dimensions (none here)

        cutlass::library::NumericTypeID::kF32,            // Base data type of complex-valued accumulation
        cutlass::library::NumericTypeID::kF32,            // Base data type of complex-valued alpha/beta scalars

        &options.alpha,                                   // Pointer to alpha scalar, of type complex<T>

        cutlass::library::NumericTypeID::kF16,            // Base data type of complex-valued A matrix
        cutlass::library::LayoutTypeID::kColumnMajor,     // Layout of A matrix
        cutlass::library::ComplexTransform::kConjugate,   // Complex transformation on A matrix operand

        ptr_A_real.get(),                                 // Pointer to array of pointers to real part of A matrix
        ptr_A_imag.get(),                                 // Pointer to array of pointers to imaginary part of A matrix

        lda,                                              // Leading dimension of real part of A matrix
        lda,                                              // Leading dimension of imaginary part of A matrix

        cutlass::library::NumericTypeID::kF16,            // Base data type of complex-valued B matrix
        cutlass::library::LayoutTypeID::kColumnMajor,     // Layout of B matrix
        cutlass::library::ComplexTransform::kNone,        // Complex transformation on B matrix operand

        ptr_B_real.get(),                                 // Pointer to array of pointers to real part of B matrix
        ptr_B_imag.get(),                                 // Pointer to array of pointers to imaginary part of B matrix

        ldb,                                              // Leading dimension of real part of B matrix
        ldb,                                              // Leading dimension of imaginary part of B matrix

        &options.beta,                                    // Pointer to beta scalar, of type complex<T>

        cutlass::library::NumericTypeID::kF16,            // Base data type of complex valued C and D matrices

        ptr_C_real.get(),                                 // Pointer to array of pointers to real part of C matrix
        ptr_C_imag.get(),                                 // Pointer to array of pointers to imaginary part of C matrix

        ldc,                                              // Leading dimension of real part of C matrix
        ldc,                                              // Leading dimension of imaginary part of C matrix

        ptr_D_real.get(),                                 // Pointer to array of pointers to real part of D matrix
        ptr_D_imag.get(),                                 // Pointer to array of pointers to imaginary part of D matrix

        ldd,                                              // Leading dimension of real part of D matrix
        ldd                                               // Leading dimension of imaginary part of D matrix
      );

      if (result.status != cutlass::Status::kSuccess) {
        std::cerr << "CUTLASS internal error - configuration not supported" << std::endl;
        return result;
      }
    }

    //
    // Stop profiling loop
    //

    // Record an event when the GEMM operations have been launched.
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Compute average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    for (auto event : events) {
      (void)cudaEventDestroy(event);
    }

    if (handle.get_last_operation()) {
      std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl;
    }

    //
    // Compute reference in device code
    //

    if (options.reference_check) {

      result.passed = true;

      for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) {
        cutlass::reference::device::GemmPlanarComplex<
          Element, LayoutA,
          Element, LayoutB,
          Element, LayoutC,
          ElementAccumulator
        >(
          problem_size,
          options.alpha,
          {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A},
          cutlass::ComplexTransform::kConjugate,
          {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B},
          cutlass::ComplexTransform::kNone,
          options.beta,
          {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C},
          {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D}
        );

        Element epsilon = 0.1_hf;
        Element nonzero_floor = 0.1_hf;

        result.passed = cutlass::reference::device::BlockCompareRelativelyEqual(
          tensor_D.get() + idx * batch_stride_D,
          tensor_D_ref.get() + idx * batch_stride_D,
          batch_stride_D,
          epsilon,
          nonzero_floor
        );
      }

      if (result.passed) {
        std::cout << "Reference check passed." << std::endl;
      }
      else {
        std::cerr << "Error - reference check failed." << std::endl;
      }
    }

    std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
    std::cout << " GFLOPs: " << result.gflops << std::endl;

    return result;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  //
  // This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
  //
  // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit.
  //
  // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit.
  //

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major < 7) {
    std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70."
              << std::endl;
    return -1;
  }
  else if (props.major == 7 && props.minor <= 2) {
    //
    // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example.
    //
    if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
      std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
      return -1;
    }
  }
  else if (props.major == 7 && props.minor >= 5) {
    //
    // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example.
    //
    if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
      std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
      return -1;
    }
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  // Execute one problem size
  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

  TestbedPlanarComplex testbed(options);

  Result result = testbed.profile(options);

  return result.passed ? 0 : -1;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

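The pointer-array setup in `TestbedPlanarComplex::profile()` above issues one `cudaMemcpy` per batch element. A sketch of an equivalent approach that stages all pointers on the host and copies them in a single transfer (hypothetical helper, assuming the same batch layout; not part of the example):

```
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Build all per-batch base pointers on the host, then copy them to the
// device-resident pointer array in one transfer.
cudaError_t fill_pointer_array(void **device_ptrs,    // batch_count slots on the device
                               void *base,            // start of the batched tensor
                               int64_t batch_stride,  // elements between batches
                               int64_t element_size,  // size of one element in bytes
                               int batch_count) {
  std::vector<void *> host_ptrs(batch_count);
  for (int b = 0; b < batch_count; ++b) {
    host_ptrs[b] = static_cast<char *>(base) + b * batch_stride * element_size;
  }
  return cudaMemcpy(device_ptrs, host_ptrs.data(),
                    sizeof(void *) * batch_count, cudaMemcpyHostToDevice);
}
```
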
@ -22,16 +22,14 @@

set(CUTLASS_EXAMPLES_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common)

function(cutlass_example_add_executable)
function(cutlass_example_add_executable NAME)

  set(options)
  set(oneValueArgs)
  set(multiValueArgs)
  cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  cutlass_add_executable(${__UNPARSED_ARGUMENTS})

  list(GET __UNPARSED_ARGUMENTS 0 NAME)
  cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS})

  target_link_libraries(
    ${NAME}
@ -59,7 +57,10 @@ foreach(EXAMPLE
  05_batched_gemm
  06_splitK_gemm
  07_volta_tensorop_gemm
  08_turing_tensorop_gemm)
  08_turing_tensorop_gemm
  10_planar_complex
  11_planar_complex_array
)

  add_subdirectory(${EXAMPLE})
  add_dependencies(cutlass_examples ${EXAMPLE})

@ -46,61 +46,68 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
// Determine the appropriate way to target PTX's "ldmatrix" instruction.
//
/////////////////////////////////////////////////////////////////////////////////////////////////

#if (__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ == 2)
#define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 1
#else
#define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 0
#endif

#if ! defined(CUDA_NVVM_GET_SHARED_POINTER_ENABLED)
#define CUDA_NVVM_GET_SHARED_POINTER_ENABLED (CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED)
#endif

#if ! defined(CUDA_LDMATRIX_SUPPORTED)
#define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))
#endif

#if ! defined(CUDA_LDMATRIX_ENABLED)
#define CUDA_LDMATRIX_ENABLED (CUDA_LDMATRIX_SUPPORTED)
#define CUDA_LDMATRIX_ENABLED CUDA_LDMATRIX_SUPPORTED
#endif

#if (CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#if CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
#define CUDA_LDMATRIX_ACTIVATED 1
#else
#define CUDA_LDMATRIX_ACTIVATED 0
#endif

#if defined(CUTLASS_GET_SMEM_POINTER)
// Use the existing implementation
#elif CUDA_NVVM_GET_SHARED_POINTER_ENABLED
#if ! defined(NVVM_GET_SMEM_POINTER)
#define NVVM_GET_SMEM_POINTER
/////////////////////////////////////////////////////////////////////////////////////////////////

#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
#endif

#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif

#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void*);
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void *);
}
#endif
#define CUTLASS_GET_SMEM_POINTER(ptr) __nvvm_get_smem_pointer((void*)ptr)
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED

/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
  return __nvvm_get_smem_pointer(const_cast<void *>(ptr));
}

/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
  return __nvvm_get_smem_pointer(ptr);
}

#endif
/////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ void ldsm<layout::RowMajor, 1>(
  Array<unsigned, 1> & D,
  void const* ptr) {

#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)

  unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
  unsigned addr = cutlass_get_smem_pointer(ptr);

  int x;
  asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr));
@ -120,9 +127,9 @@ inline __device__ void ldsm<layout::RowMajor, 2>(
  Array<unsigned, 2> & D,
  void const* ptr) {

#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)

  unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
  unsigned addr = cutlass_get_smem_pointer(ptr);

  int x, y;
  asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr));
@ -142,9 +149,9 @@ inline __device__ void ldsm<layout::RowMajor, 4>(
  Array<unsigned, 4> & D,
  void const* ptr) {

#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)

  unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
  unsigned addr = cutlass_get_smem_pointer(ptr);

  int x, y, z, w;
  asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr));
@ -167,9 +174,10 @@ template <>
inline __device__ void ldsm<layout::ColumnMajor, 1>(
  Array<unsigned, 1> & D,
  void const* ptr) {

#if CUDA_LDMATRIX_ACTIVATED

  unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
  unsigned addr = cutlass_get_smem_pointer(ptr);

  int x;
  asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr));
@ -189,9 +197,9 @@ inline __device__ void ldsm<layout::ColumnMajor, 2>(
  Array<unsigned, 2> & D,
  void const* ptr) {

#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)

  unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
  unsigned addr = cutlass_get_smem_pointer(ptr);

  int x, y;
  asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr));
@ -211,9 +219,9 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
  Array<unsigned, 4> & D,
  void const* ptr) {

#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)

  unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
  unsigned addr = cutlass_get_smem_pointer(ptr);

  int x, y, z, w;
  asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr));

@ -30,7 +30,9 @@

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/arch.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -49,6 +51,11 @@ struct OpMultiplyAddSaturate;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the complex multiply-add operation
struct OpMultiplyAddComplex;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the inner product is defined by (XOR, POPC)
struct OpXorPopc;

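`OpMultiplyAddComplex` above joins the existing operator tags. These are empty types used only for compile-time dispatch; the sketch below illustrates the general pattern with hypothetical traits (not CUTLASS's actual specialization machinery):

```
// Hypothetical illustration of tag dispatch, not CUTLASS code: an empty
// struct selects a template specialization at compile time, at zero cost.
struct OpMultiplyAdd;
struct OpMultiplyAddComplex;

template <typename Operator> struct OperatorTraits;

template <> struct OperatorTraits<OpMultiplyAdd> {
  static int const kRealFlopsPerMultiplyAdd = 2;  // one multiply + one add
};

template <> struct OperatorTraits<OpMultiplyAddComplex> {
  static int const kRealFlopsPerMultiplyAdd = 8;  // four real multiply-adds
};

static_assert(OperatorTraits<OpMultiplyAddComplex>::kRealFlopsPerMultiplyAdd == 8,
              "complex multiply-add expands to four real multiply-adds");
```
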
@ -27,7 +27,11 @@
 */
#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "mma.h"
#include "cutlass/layout/matrix.h"
@ -84,6 +88,7 @@ struct Mma<
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  CUTLASS_HOST_DEVICE
  void operator()(
@ -139,6 +144,7 @@ struct Mma<
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  CUTLASS_HOST_DEVICE
  void operator()(
@ -194,6 +200,7 @@ struct Mma<
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  CUTLASS_HOST_DEVICE
  void operator()(
@ -249,6 +256,7 @@ struct Mma<
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  CUTLASS_HOST_DEVICE
  void operator()(
@ -310,6 +318,7 @@ struct Mma<
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  /// Multiply-add
  CUTLASS_HOST_DEVICE
@ -385,6 +394,7 @@ struct Mma<
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  /// Multiply-add
  CUTLASS_HOST_DEVICE
@ -460,6 +470,7 @@ struct Mma<
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  /// Multiply-add
  CUTLASS_HOST_DEVICE
@ -535,6 +546,7 @@ struct Mma<
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  /// Multiply-add
  CUTLASS_HOST_DEVICE

@ -28,7 +28,11 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/arch/wmma.h"

@ -93,6 +97,7 @@ struct Mma<
  using FragmentC = Array<half_t, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  CUTLASS_HOST_DEVICE
  void operator()(
@ -154,6 +159,7 @@ struct Mma<
  using FragmentC = Array<float, 4>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -215,6 +221,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -271,6 +278,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -327,6 +335,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -384,6 +393,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -446,6 +456,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -502,6 +513,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -558,6 +570,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -614,6 +627,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -676,6 +690,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -732,6 +747,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -788,6 +804,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -844,6 +861,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -906,6 +924,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -962,6 +981,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -1018,6 +1038,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -1074,6 +1095,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
@ -1136,6 +1158,7 @@ struct Mma<
  using FragmentC = Array<int, 2>;

  using Operator = OpXorPopc;
  using ArchTag = arch::Sm75;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE

@ -68,24 +68,6 @@

namespace cutlass {
namespace arch {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// MemoryKind class (Shared vs. Global memory)
/////////////////////////////////////////////////////////////////////////////////////////////////
enum class MemoryKind {
  kShared,  // Data resides in shared memory
  kGlobal   // Data resides in global memory
};


/////////////////////////////////////////////////////////////////////////////////////////////////
/// WarpParams holds architecture-specific constants
/////////////////////////////////////////////////////////////////////////////////////////////////
struct WarpParams {
  static int const kThreadsPerWarp = 32;
  static int const kQuadsPerWarp = 8;
  static int const kThreadsPerQuad = 4;
};

////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass data types => nvcuda::wmma data types
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -196,7 +178,6 @@ template <
struct Wmma;
/////////////////////////////////////////////////////////////////////////////////////////////////


} // namespace arch
} // namespace cutlass

@ -1,105 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates exposing warp matrix multiply-add (WMMA) operations
*/
#pragma once

#include "cutlass/arch/wmma.h"

namespace cutlass {
namespace arch {

/////////////////////////////////////////////////////////////////////////////////////////////////
///
/// WMMA structures to enclose * PTX * instruction string
///
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////
/// WMMA PTX string load for A, B, and C matrices
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Shape_,                          ///< Size of the matrix product (concept: GemmShape)
  typename Element_,                        ///< Data type of elements
  typename Layout_,                         ///< Layout of matrix (concept: MatrixLayout)
  MemoryKind Memory = MemoryKind::kShared   ///< Data resides in shared or global memory
>
struct PtxWmmaLoadA;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,                          ///< Size of the matrix product (concept: GemmShape)
  typename Element_,                        ///< Data type of elements
  typename Layout_,                         ///< Layout of matrix (concept: MatrixLayout)
  MemoryKind Memory = MemoryKind::kShared   ///< Data resides in shared or global memory
>
struct PtxWmmaLoadB;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Shape_,                          ///< Size of the matrix product (concept: GemmShape)
  typename Element_,                        ///< Data type of elements
  typename Layout_,                         ///< Layout of matrix (concept: MatrixLayout)
  MemoryKind Memory = MemoryKind::kShared   ///< Data resides in shared or global memory
>
struct PtxWmmaLoadC;
/////////////////////////////////////////////////////////////////////////////////////////////////


/////////////////////////////////////////////////////////////////////////////////////////////////
/// WMMA Matrix multiply-add operation
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Shape_,                                   ///< Size of the matrix product (concept: GemmShape)
  typename ElementA_,                                ///< Data type of A elements
  typename LayoutA_,                                 ///< Layout of A matrix (concept: MatrixLayout)
  typename ElementB_,                                ///< Data type of B elements
  typename LayoutB_,                                 ///< Layout of B matrix (concept: MatrixLayout)
  typename ElementC_,                                ///< Element type of C matrix
  typename LayoutC_,                                 ///< Layout of C matrix (concept: MatrixLayout)
  typename Operator = cutlass::arch::OpMultiplyAdd   ///< Inner product operator (multiply-add, xor.popc)
>
struct PtxWmma;
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////
/// WMMA store for matrix D
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Shape_,                          ///< Size of the matrix product (concept: GemmShape)
  typename Element_,                        ///< Data type of elements
  typename Layout_,                         ///< Layout of matrix (concept: MatrixLayout)
  MemoryKind Memory = MemoryKind::kShared   ///< Data resides in shared or global memory
>
struct PtxWmmaStoreD;
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -28,7 +28,11 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/layout/matrix.h"

////////////////////////////////////////////////////////////////////////////////
@ -68,6 +72,7 @@ struct Wmma<
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpMultiplyAdd;
  using ArchTag = arch::Sm70;

  // check supported wmma shape for the given multiplicand data types
  static_assert(

@ -28,7 +28,11 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/layout/matrix.h"

////////////////////////////////////////////////////////////////////////////////
@ -65,6 +69,7 @@ struct Wmma<
  using ElementC = int32_t;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpMultiplyAdd;
  using ArchTag = arch::Sm72;

  // check supported wmma shape for the given multiplicand data types
  static_assert(
@ -145,6 +150,7 @@ struct Wmma<
  using ElementC = int32_t;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpMultiplyAdd;
  using ArchTag = arch::Sm72;

  // check supported wmma shape for the given multiplicand data types
  static_assert(

@ -28,7 +28,11 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/layout/matrix.h"

////////////////////////////////////////////////////////////////////////////////
@ -65,6 +69,7 @@ struct Wmma<
  using ElementC = int32_t;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpMultiplyAdd;
  using ArchTag = arch::Sm75;

  // check supported wmma shape for the given multiplicand data types
  static_assert(
@ -143,6 +148,7 @@ struct Wmma<
  using ElementC = int32_t;
  using LayoutC = LayoutC_;
  using Operator = cutlass::arch::OpXorPopc;
  using ArchTag = arch::Sm75;

  // check supported wmma shape for the given multiplicand data types
  static_assert(

97
include/cutlass/array_planar_complex.h
Normal file
@ -0,0 +1,97 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Array holding planar complex elements.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Array holding planar complex elements
template <typename Element_, int N>
struct ArrayPlanarComplex {

  /// Underlying real element
  using Element = Element_;

  /// Number of logical elements
  static size_t const kElements = N;

  /// Underlying Fragment of real-valued elements
  using ArrayReal = Array<Element, N>;

public:

  /// Fragment of real-valued elements representing the real part
  ArrayReal real;

  /// Fragment of real-valued elements representing the imaginary part
  ArrayReal imag;

public:

  /// Ctor
  CUTLASS_HOST_DEVICE
  ArrayPlanarComplex() { }

  /// Ctor
  CUTLASS_HOST_DEVICE
  ArrayPlanarComplex(
    ArrayReal const &real_,
    ArrayReal const &imag_
  ):
    real(real_), imag(imag_) { }

  /// Sets the array to zero efficiently
  CUTLASS_HOST_DEVICE
  void clear() {
    real.clear();
    imag.clear();
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to deduce template arguments
template <typename Element, int N>
CUTLASS_HOST_DEVICE
ArrayPlanarComplex<Element, N>
make_ArrayPlanarComplex(Array<Element, N> const &real, Array<Element, N> const &imag) {
  return ArrayPlanarComplex<Element, N>(real, imag);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
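As background for the new container above: planar-complex storage keeps all real parts in one contiguous array and all imaginary parts in a second, disjoint array, rather than interleaving (re, im) pairs. A minimal standalone sketch of the idea, assuming a simplified ArrayLike stand-in for cutlass::Array (not part of this commit):

```
#include <array>
#include <cstdio>

// Simplified stand-in for cutlass::Array<T, N>: a statically sized
// vector type with a clear() method (assumption for this sketch).
template <typename Element, int N>
struct ArrayLike : std::array<Element, N> {
  void clear() { this->fill(Element(0)); }
};

// Mirrors the shape of cutlass::ArrayPlanarComplex: two disjoint planes.
template <typename Element, int N>
struct PlanarComplexSketch {
  ArrayLike<Element, N> real;  // real parts, contiguous
  ArrayLike<Element, N> imag;  // imaginary parts, contiguous
};

int main() {
  PlanarComplexSketch<float, 4> frag;
  frag.real.clear();
  frag.imag.clear();
  frag.real[0] = 1.0f;  // element 0 holds 1 + 2i
  frag.imag[0] = 2.0f;
  std::printf("element 0 = %f + %fi\n", frag.real[0], frag.imag[0]);
  return 0;
}
```

Keeping the two planes disjoint is what lets the epilogue and iterators below reuse real-valued components unchanged on each plane.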
@ -25,7 +25,11 @@
#pragma once

#include <cuComplex.h>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/half.h"
@ -351,6 +355,16 @@ CUTLASS_HOST_DEVICE R norm_accumulate(complex<T> const &z, R const &accumulator)
         static_cast<R>(imag(z)) * static_cast<R>(imag(z));
}

/// Returns the complex conjugate
CUTLASS_HOST_DEVICE float conj(float const &z) {
  return z;
}

/// Returns the complex conjugate
CUTLASS_HOST_DEVICE double conj(double const &z) {
  return z;
}

/// Returns the complex conjugate
template <typename T>
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
@ -414,6 +428,10 @@ CUTLASS_HOST_DEVICE complex<T> sin(complex<T> const &z) {
template <typename T>
struct RealType< complex<T> > {
  using Type = T;

  static complex<T> from_real(double x) {
    return complex<T>(static_cast<T>(x));
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -438,5 +456,18 @@ cutlass::complex<double> from_real<cutlass::complex<double> >(double r) {

//////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct is_complex {
  static bool const value = false;
};

template <typename T>
struct is_complex<complex<T>> {
  static bool const value = true;
};

//////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

//////////////////////////////////////////////////////////////////////////////////////////////////

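The is_complex trait added above uses the usual partial-specialization pattern for compile-time type classification. A standalone sketch of how such a trait is consumed, with complex_sketch as a hypothetical stand-in for cutlass::complex:

```
#include <cstdio>

// Simplified re-statement of the trait from cutlass/complex.h.
template <typename T> struct complex_sketch { T re, im; };

template <typename T> struct is_complex {
  static bool const value = false;
};
template <typename T> struct is_complex<complex_sketch<T>> {
  static bool const value = true;
};

// Example use: select behavior at compile time based on the trait.
template <typename T>
void describe() {
  std::printf("%s\n", is_complex<T>::value ? "complex" : "real");
}

int main() {
  describe<float>();                  // prints "real"
  describe<complex_sketch<float>>();  // prints "complex"
  return 0;
}
```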
@ -28,6 +28,12 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <stdint.h>
#endif

#include "cutlass/cutlass.h"

namespace cutlass {

@ -35,10 +35,21 @@ namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
#define CUTLASS_DEVICE __forceinline__ __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
#define CUTLASS_DEVICE __forceinline__ __device__
#else
#define CUTLASS_HOST_DEVICE inline
#endif

/// Status code returned by CUTLASS operations
enum class Status {
  kSuccess,                    ///< Operation was successful.
  kErrorMisalignedOperand,     ///< operands fail alignment requirements.
  kErrorInvalidDataType,       ///< DataType fails requirement.
  kErrorInvalidLayout,         ///< Layout fails alignment requirement.
  kErrorInvalidProblem,        ///< Specified problem size is not supported by operator.
  kErrorNotSupported,          ///< Operation is not supported on current device.
@ -48,12 +59,15 @@ enum class Status {
};

/// Convert cutlass status to status strings
static inline char const* cutlassGetStatusString(cutlass::Status status) {
CUTLASS_HOST_DEVICE
static char const* cutlassGetStatusString(cutlass::Status status) {
  switch (status) {
    case cutlass::Status::kSuccess:
      return "Success";
    case cutlass::Status::kErrorMisalignedOperand:
      return "Error Misaligned Operand";
    case cutlass::Status::kErrorInvalidDataType:
      return "Error Invalid Data Type";
    case cutlass::Status::kErrorInvalidLayout:
      return "Error Invalid Layout";
    case cutlass::Status::kErrorInvalidProblem:
@ -79,16 +93,6 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) {

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
#define CUTLASS_DEVICE __forceinline__ __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
#define CUTLASS_DEVICE __forceinline__ __device__
#else
#define CUTLASS_HOST_DEVICE inline
#endif

#define CUTLASS_ASSERT(x) assert(x)

////////////////////////////////////////////////////////////////////////////////////////////////////
@ -115,6 +119,12 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct Debug {
  typename T::X x;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

static const int NUM_THREADS_PER_WARP = 32;
static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2;
@ -131,6 +141,14 @@ int LaneId() {
  return ret;
}

/// Computes SM number the thread is running on
CUTLASS_DEVICE
int SmId() {
  int ret;
  asm ("mov.u32 %0, %%smid;" : "=r"(ret));
  return ret;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

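The Status enum and cutlassGetStatusString above support the usual host-side error-checking idiom. A standalone sketch under simplified names (to_string here is a hypothetical stand-in for cutlassGetStatusString, and the enum is abbreviated):

```
#include <cstdio>

enum class Status { kSuccess, kErrorMisalignedOperand, kErrorNotSupported };

char const *to_string(Status status) {
  switch (status) {
    case Status::kSuccess:                return "Success";
    case Status::kErrorMisalignedOperand: return "Error Misaligned Operand";
    case Status::kErrorNotSupported:      return "Error Not Supported";
  }
  return "Unknown";
}

int main() {
  Status status = Status::kErrorNotSupported;
  if (status != Status::kSuccess) {
    std::printf("operation failed: %s\n", to_string(status));
  }
  return 0;
}
```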
@ -167,23 +167,23 @@ public:
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);    // D = alpha * Accum + X

    /// Clamping constant value
    ElementCompute const kClamp = ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));

    intermediate = max_accumulator(intermediate, -kClamp);
    intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
    ElementCompute const kClamp =
        ElementCompute((1U << (sizeof_bits<ElementOutput>::value - 1)) - 1);

    intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1));
    intermediate = min_accumulator(intermediate, kClamp);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

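The kClamp change above restates the clamp bounds so the limits of the signed output range appear directly, using an unsigned shift. For an 8-bit output both the old and new forms clamp to [-128, 127], as this small standalone check illustrates (the rationale is not stated in the diff itself; this only verifies the ranges agree):

```
#include <cstdio>

int main() {
  int const bits = 8;  // e.g. an int8_t output element

  // Before: kClamp = 2^(b-1); clamp to [-kClamp, kClamp - 1]
  int old_clamp = 1 << (bits - 1);
  std::printf("old range: [%d, %d]\n", -old_clamp, old_clamp - 1);

  // After: kClamp = 2^(b-1) - 1; clamp to [-kClamp - 1, kClamp]
  int new_clamp = (1U << (bits - 1)) - 1;
  std::printf("new range: [%d, %d]\n", -new_clamp - 1, new_clamp);

  // Both print [-128, 127]: the representable range of a signed 8-bit value.
  return 0;
}
```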
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))

/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
@ -320,7 +320,142 @@ public:

#endif // Conditional guards to enable partial specialization for packed integers

/////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements then clamps
/// the output before converting to the output element type.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Note: The below method only works for small k dimensions. The default
/// approach is above.
/// TODO: Add logic to fall back to the default approach
template <
  /// Data type used to load and store tensors
  typename ElementOutput_,
  /// Number of elements computed per operation
  int Count,
  /// Rounding mode
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class FastLinearCombinationClamp {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = int;
  using ElementCompute = float;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {
    /// scales accumulators
    ElementCompute alpha;
    /// scales source tensor
    ElementCompute beta;
    /// pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *alpha_ptr;
    /// pointer to source scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params()
        : alpha(ElementCompute(1)),
          beta(ElementCompute(0)),
          alpha_ptr(nullptr),
          beta_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

 private:
  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

 public:
  /// Constructs the function object, possibly loading from pointers in host
  /// memory
  CUTLASS_HOST_DEVICE
  FastLinearCombinationClamp(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return beta_ != ElementCompute(0); }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const &accumulator,
                            FragmentOutput const &source,
                            ElementCompute uniform = ElementCompute(0)) const {
    // Convert source to internal compute numeric type
    FastNumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
        source_converter;
    FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Compute linear scaling in floating point
    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    minimum<ComputeFragment> min_accumulator;
    maximum<ComputeFragment> max_accumulator;

    // Float min-max
    intermediate =
        mul_add_source(beta_, converted_source);  // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator,
                                       intermediate);  // D = alpha * Accum + X

    /// Clamping constant value
    ElementCompute const kClamp =
        ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));

    intermediate = max_accumulator(intermediate, -kClamp);
    intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));

    // Convert to destination numeric type
    FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    return destination_converter(intermediate);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue

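FastLinearCombinationClamp follows the same epilogue-functor protocol as the other thread-level output operators: a host-constructable Params struct, is_source_needed(), set_k_partition(), and an operator() computing D = alpha * accumulator + beta * source. A scalar sketch of that protocol (illustrative only; the real functor operates on Array<> fragments):

```
#include <cstdio>

// Scalar stand-in for the epilogue functor protocol (assumption: real
// CUTLASS functors take fragments, not scalars).
struct LinearCombinationSketch {
  float alpha, beta;

  // Skipping the source load is legal when beta == 0.
  bool is_source_needed() const { return beta != 0.0f; }

  // D = alpha * accumulator + beta * source
  float operator()(int accumulator, float source) const {
    return alpha * static_cast<float>(accumulator) + beta * source;
  }
};

int main() {
  LinearCombinationSketch op{0.5f, 1.0f};
  std::printf("D = %f\n", op(8, 2.0f));  // 0.5 * 8 + 1.0 * 2 = 6
  return 0;
}
```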
@ -0,0 +1,196 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Functor performing linear combination operations on planar-complex arrays
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/complex.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to arrays of planar-complex elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Note, as with most CUTLASS components for planar complex, the template arguments describe
/// the underlying real data type.
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationPlanarComplex {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = ArrayPlanarComplex<ElementOutput, kCount>;
  using FragmentAccumulator = ArrayPlanarComplex<ElementAccumulator, kCount>;
  using ComputeFragment = ArrayPlanarComplex<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    complex<ElementCompute> alpha;              ///< scales accumulators
    complex<ElementCompute> beta;               ///< scales source tensor
    complex<ElementCompute> const *alpha_ptr;   ///< pointer to accumulator scalar - if not null, loads it from memory
    complex<ElementCompute> const *beta_ptr;    ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      complex<ElementCompute> alpha,
      complex<ElementCompute> beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      complex<ElementCompute> const *alpha_ptr,
      complex<ElementCompute> const *beta_ptr
    ): alpha(complex<ElementCompute>()), beta(complex<ElementCompute>()), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  complex<ElementCompute> alpha_;
  complex<ElementCompute> beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationPlanarComplex(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentOutput const &source) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    ComputeFragment converted_source(
      source_converter(source.real),
      source_converter(source.imag));

    ComputeFragment converted_accumulator(
      accumulator_converter(accumulator.real),
      accumulator_converter(accumulator.imag));

    // Perform binary operations
    ComputeFragment intermediate;

    multiplies<Array<ElementCompute, kCount> > mul_op;
    multiply_add<Array<ElementCompute, kCount> > mul_add_op;

    // complex multiply: I = beta * C
    intermediate.real = mul_op(beta_.real(), converted_source.real);
    intermediate.imag = mul_op(beta_.real(), converted_source.imag);

    intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real);
    intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag);

    // complex multiply-add: I = alpha * AB + I
    intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real);
    intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag);

    intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
    intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return FragmentOutput(
      destination_converter(intermediate.real),
      destination_converter(intermediate.imag));
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
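The operator above realizes the complex products beta * C and alpha * AB as sequences of real-valued multiplies and multiply-adds over the two planes. A standalone scalar check of that exact sequence against std::complex:

```
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> alpha(2.0f, 1.0f), beta(0.5f, -1.0f);
  std::complex<float> accum(3.0f, 4.0f), source(1.0f, 2.0f);

  // Same sequence of real-valued multiply / multiply-add steps as the
  // functor above, written out on scalars:
  float re = beta.real() * source.real();
  float im = beta.real() * source.imag();
  re += -beta.imag() * source.imag();   // I = beta * C
  im +=  beta.imag() * source.real();
  re += alpha.real() * accum.real();
  im += alpha.real() * accum.imag();
  re += -alpha.imag() * accum.imag();   // I += alpha * AB
  im +=  alpha.imag() * accum.real();

  std::complex<float> ref = alpha * accum + beta * source;
  std::printf("manual: (%f, %f)  reference: (%f, %f)\n",
              re, im, ref.real(), ref.imag());
  // Both print (4.5, 11.0).
  return 0;
}
```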
@ -180,9 +180,13 @@ public:
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////

// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
    ((__CUDACC_VER_MAJOR__ > 10) || \
     ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))

/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
///
@ -327,6 +331,8 @@ public:
  }
};

#endif // Conditional guards to enable partial specialization for packed integers

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread

@ -58,16 +58,25 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization defining sensible defaults for epilogues in the complex*complex case.
// 4 real-valued mma operations (Complex)
// A = (ar + j ai), B = (br + j bi), D = AB
// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br)
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  /// Epilogue Shape
  typename Shape_,
  /// Warp-level mma operator
  typename WarpMmaTensorOp_,
  /// Number of k partitions
  int PartitionsK,
  /// Epilogue output operator
  typename OutputOp_,
  int ElementsPerAccess
>
  /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
  int ElementsPerAccess,
  /// Multiply-add operator
  typename Operator_ = arch::OpMultiplyAddComplex>
struct DefaultEpilogueComplexTensorOp {

  using Shape = Shape_;
@ -75,6 +84,7 @@ struct DefaultEpilogueComplexTensorOp {
  static int const kPartitionsK = PartitionsK;
  using OutputOp = OutputOp_;
  static int const kElementsPerAccess = ElementsPerAccess;
  using Operator = Operator_;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaTensorOp::LayoutC;

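A standalone check of the complex-product identity quoted in the comment above, D = (ar*br - ai*bi) + j (ar*bi + ai*br), which is why the complex*complex case costs four real-valued mma operations:

```
#include <complex>
#include <cstdio>

int main() {
  // (ar + j ai)(br + j bi) = (ar*br - ai*bi) + j (ar*bi + ai*br)
  float ar = 1.5f, ai = -2.0f, br = 3.0f, bi = 0.25f;

  float dr = ar * br - ai * bi;  // real part: two real products
  float di = ar * bi + ai * br;  // imag part: two real products

  std::complex<float> ref =
      std::complex<float>(ar, ai) * std::complex<float>(br, bi);
  std::printf("manual: (%f, %f)  reference: (%f, %f)\n",
              dr, di, ref.real(), ref.imag());
  return 0;
}
```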
@ -0,0 +1,197 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Constructs a default epilogue for planar complex outputs.

  This template reuses components for real-valued epilogues and applies them to planar complex
  output matrices.

*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/array_planar_complex.h"

#include "cutlass/arch/arch.h"

#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"

#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues.
template <
  typename ThreadblockShape_,
  typename WarpMma_,
  typename OpcodeClass_,
  typename ArchTag_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues.
template <
  typename ThreadblockShape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
  ThreadblockShape_,
  WarpMmaOperator_,
  arch::OpClassTensorOp,
  arch::Sm70,
  PartitionsK,
  OutputOp_,
  ElementsPerAccess> {

  using RealEpilogue = DefaultEpilogueVoltaTensorOp<
    ThreadblockShape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputOp_,
    ElementsPerAccess
  >;

  using Epilogue = EpiloguePlanarComplex<
    ThreadblockShape_,
    WarpMmaOperator_,
    PartitionsK,
    typename RealEpilogue::OutputTileIterator,
    typename RealEpilogue::AccumulatorFragmentIterator,
    typename RealEpilogue::WarpTileIterator,
    typename RealEpilogue::SharedLoadIterator,
    OutputOp_,
    typename RealEpilogue::Padding
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues.
template <
  typename ThreadblockShape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
  ThreadblockShape_,
  WarpMmaOperator_,
  arch::OpClassTensorOp,
  arch::Sm75,
  PartitionsK,
  OutputOp_,
  ElementsPerAccess> {

  using RealEpilogue = DefaultEpilogueTensorOp<
    ThreadblockShape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputOp_,
    ElementsPerAccess
  >;

  using Epilogue = EpiloguePlanarComplex<
    ThreadblockShape_,
    WarpMmaOperator_,
    PartitionsK,
    typename RealEpilogue::OutputTileIterator,
    typename RealEpilogue::AccumulatorFragmentIterator,
    typename RealEpilogue::WarpTileIterator,
    typename RealEpilogue::SharedLoadIterator,
    OutputOp_,
    typename RealEpilogue::Padding
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues.
template <
  typename ThreadblockShape_,
  typename WarpMmaOperator_,
  typename ArchTag_,
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
  ThreadblockShape_,
  WarpMmaOperator_,
  arch::OpClassSimt,
  ArchTag_,
  PartitionsK,
  OutputOp_,
  ElementsPerAccess> {

  using RealEpilogue = DefaultEpilogueSimt<
    ThreadblockShape_,
    WarpMmaOperator_,
    OutputOp_,
    ElementsPerAccess
  >;

  using Epilogue = EpiloguePlanarComplex<
    ThreadblockShape_,
    WarpMmaOperator_,
    PartitionsK,
    typename RealEpilogue::OutputTileIterator,
    typename RealEpilogue::AccumulatorFragmentIterator,
    typename RealEpilogue::WarpTileIterator,
    typename RealEpilogue::SharedLoadIterator,
    OutputOp_,
    typename RealEpilogue::Padding
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
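DefaultEpiloguePlanarComplex above selects a real-valued default epilogue by partial specialization on the (opcode class, architecture tag) pair and wraps its iterators in EpiloguePlanarComplex. A standalone sketch of this tag-dispatch pattern, with all names below illustrative stand-ins:

```
#include <cstdio>

// Tag types standing in for arch::OpClass* and arch::Sm* tags.
struct OpClassTensorOp {};
struct OpClassSimt {};
struct Sm70 {};
struct Sm75 {};

// Primary template is left undefined; specializations pick the default.
template <typename OpClass, typename Arch> struct DefaultEpilogueSketch;

template <> struct DefaultEpilogueSketch<OpClassTensorOp, Sm70> {
  static char const *name() { return "Volta TensorOp epilogue"; }
};
template <> struct DefaultEpilogueSketch<OpClassTensorOp, Sm75> {
  static char const *name() { return "Turing TensorOp epilogue"; }
};
// SIMT works on any architecture tag, so it stays generic in Arch.
template <typename Arch> struct DefaultEpilogueSketch<OpClassSimt, Arch> {
  static char const *name() { return "SIMT epilogue"; }
};

int main() {
  std::printf("%s\n", DefaultEpilogueSketch<OpClassTensorOp, Sm75>::name());
  std::printf("%s\n", DefaultEpilogueSketch<OpClassSimt, Sm70>::name());
  return 0;
}
```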
@ -23,7 +23,7 @@
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
  \brief Epilogue for threadblock scoped GEMMs using WMMA.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

@ -146,6 +146,54 @@ struct DefaultInterleavedThreadMapTensorOp {

////////////////////////////////////////////////////////////////////////////////

/// Defines the optimal thread map for TensorOp accumulator layouts
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
          typename Element_, int ElementsPerAccess, int InterleavedK>
struct DefaultInterleavedConvThreadMapTensorOp {
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using Element = Element_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kInterleavedK = InterleavedK;

  //
  // Definitions
  //

  struct Detail {
    /// Tensor Operations fundamentally perform operations on 8 rows
    static int const kTensorOpRows = 8;
    static int const kWarpSize = 32;

    static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
                      !(ThreadblockShape::kN % WarpShape::kN),
                  "Divisibility");

    /// Number of warps
    using WarpCount =
        gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
                        ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;

    /// Number of participating threads
    static int const kThreads = WarpCount::kCount * kWarpSize;
  };

  //
  // ThreadMap
  //

  /// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
  /// InterleavedOutputTileThreadMap
  using Type = InterleavedConvOutputTileThreadMap<
      MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
      MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
                  WarpShape::kN / InterleavedK>,
      Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

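The Detail metaprogram above derives warp and thread counts from the threadblock and warp tile shapes. The same arithmetic on concrete numbers (the shapes here are illustrative, not taken from the commit):

```
#include <cstdio>

int main() {
  int tb_m = 128, tb_n = 128;   // threadblock tile
  int warp_m = 64, warp_n = 32; // warp tile
  int partitions_k = 1;
  int warp_size = 32;

  int warps_m = tb_m / warp_m;                        // 2
  int warps_n = tb_n / warp_n;                        // 4
  int warp_count = warps_m * warps_n * partitions_k;  // 8 warps
  int threads = warp_count * warp_size;               // 256 threads

  std::printf("%d x %d x %d warps = %d threads\n",
              warps_m, warps_n, partitions_k, threads);
  return 0;
}
```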
@ -32,7 +32,11 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@ -74,7 +78,7 @@ template <
class Epilogue :
  public EpilogueBase<
    Shape_,
    WarpMmaOperator_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
@ -84,7 +88,7 @@ public:

  using Base = EpilogueBase<
    Shape_,
    WarpMmaOperator_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
@ -171,8 +175,9 @@ public:
    OutputOp const &output_op,                  ///< Output operator
    OutputTileIterator destination_iterator,    ///< Tile iterator for destination
    AccumulatorTile const &accumulators,        ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator) {       ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    OutputTileIterator source_iterator,         ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
    int64_t imag_stride_dest = 0,               ///< Arguments required for planar complex case - not used in real-valued case
    int64_t imag_stride_src = 0) {              ///<

    typename OutputTileIterator::Fragment source_fragment;


@ -32,7 +32,11 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/matrix_shape.h"
@ -58,7 +62,7 @@ namespace threadblock {
/// Base class for epilogues defining warp-level
template <
  typename Shape_,                          ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,                ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  typename WarpShape_,                      ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                          ///< Number of partitions of the K dimension
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,               ///< Warp-scoped tile iterator writing accumulators to SMEM
@ -68,7 +72,7 @@ class EpilogueBase {
public:

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  using WarpShape = WarpShape_;
  static int const kPartitionsK = PartitionsK;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
@ -83,11 +87,10 @@ public:
  /// Accumulator element
  using ElementAccumulator = typename AccumulatorTile::Element;


  /// Number of warps
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpMmaOperator::Shape::kM,
    Shape::kN / WarpMmaOperator::Shape::kN,
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    kPartitionsK
  >;

@ -144,24 +147,6 @@ public:
        storage.data(),
        Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
  }

  CUTLASS_DEVICE
  void debug_print() {
    if (threadIdx.x == 0) {

      #pragma unroll 1
      for (int r = 0; r < Shape::kRow; ++r) {

        #pragma unroll 1
        for (int c = 0; c < Shape::kColumn; ++c) {

          printf("%d ", int(storage.data()[r * StorageShape::kColumn + c]));
        }
        printf("\n");
      }
    }
    __syncthreads();
  }
};

protected:

397
include/cutlass/epilogue/threadblock/epilogue_planar_complex.h
Normal file
@ -0,0 +1,397 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

*/

#pragma once

#include <assert.h>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator for planar-complex output representations.
///
/// Note, as with most CUTLASS components for planar complex, the template arguments describe
/// the underlying real data type.
template <
  typename Shape_,                          ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,                ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                          ///< Number of partitions of the K dimension
  typename OutputTileIterator_,             ///< Tile iterator reading and writing output tensors
  typename AccumulatorFragmentIterator_,    ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,               ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,             ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                       ///< Output operator
  typename Padding_                         ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
>
class EpiloguePlanarComplex {
public:

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  /// Output layout is always row-major
  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = ArrayPlanarComplex<
    typename WarpMmaOperator::FragmentC::Element,
    WarpMmaOperator::FragmentC::kElements
  >;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Shape of each warp-level operation
  using WarpShape = typename WarpMmaOperator::Shape;

  /// Number of warps
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    kPartitionsK
  >;

  /// Shared memory allocation
  struct SharedStorage {

    //
    // Type definitions
    //

    /// Element type of shared memory
    using Element = typename WarpTileIterator::Element;

    /// Tensor reference to shared memory allocation
    using TensorRef = typename WarpTileIterator::TensorRef;

    /// Layout of shared memory allocation
    using Layout = typename WarpTileIterator::Layout;

    /// Logical shape of the shared memory tile written to by all warps.
    using Shape = MatrixShape<
      WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
      WarpCount::kN * WarpTileIterator::Shape::kColumn
    >;

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<
      Shape::kRow + Padding::kRow,
      Shape::kColumn + Padding::kColumn
    >;

    static int const kImaginaryStride = StorageShape::kCount;

    //
    // Data members
    //

    AlignedBuffer<Element, kImaginaryStride * 2> storage;

    //
    // Methods
    //

    /// Returns a pointer to the shared memory buffer
    CUTLASS_DEVICE
    Element *data() {
      return storage.data();
    }

    /// Returns a tensor reference to the shared memory buffer
    CUTLASS_DEVICE
    TensorRef reference() {
      return TensorRef(
        storage.data(),
        Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
    }
  };

private:

  //
  // Data members
  //

  SharedStorage &shared_storage_;

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Stores a warp's fragment of accumulators to SMEM
  WarpTileIterator warp_tile_iterator_;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpiloguePlanarComplex(
    SharedStorage &shared_storage,    ///< Shared storage object
    int thread_idx,                   ///< ID of a thread within the threadblock
    int warp_idx,                     ///< ID of warp within threadblock
    int lane_idx                      ///< Id of thread within warp
  ):
    shared_storage_(shared_storage),
    shared_load_iterator_(shared_storage.reference(), thread_idx),
    warp_tile_iterator_(shared_storage.reference(), lane_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
    //
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
    int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
    int warp_m = warp_mn % WarpCount::kM;
    int warp_n = warp_mn / WarpCount::kM;

    MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};

    warp_tile_iterator_.add_tile_offset(warp_offset);
  }

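The constructor above decomposes a linear warp_idx into (m, n, k) coordinates, fastest-varying in m. The same arithmetic run standalone for a 2 x 2 x 1 arrangement of warps (counts are illustrative):

```
#include <cstdio>

int main() {
  int warps_m = 2, warps_n = 2;  // WarpCount::kM, WarpCount::kN

  for (int warp_idx = 0; warp_idx < warps_m * warps_n; ++warp_idx) {
    int warp_k  = warp_idx / (warps_m * warps_n);
    int warp_mn = warp_idx % (warps_m * warps_n);
    int warp_m  = warp_mn % warps_m;
    int warp_n  = warp_mn / warps_m;
    std::printf("warp %d -> (m=%d, n=%d, k=%d)\n",
                warp_idx, warp_m, warp_n, warp_k);
  }
  return 0;
}
```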
/// Streams the result to global memory
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
OutputTileIterator destination_iterator_real, ///< Tile iterator for destination
|
||||
OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
OutputTileIterator source_iterator_real, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
OutputTileIterator source_iterator_imag) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment_real;
|
||||
typename OutputTileIterator::Fragment source_fragment_imag;
|
||||
|
||||
if (!output_op.is_source_needed()) {
|
||||
source_iterator_real.clear_mask();
|
||||
source_iterator_imag.clear_mask();
|
||||
}
|
||||
|
||||
source_fragment_real.clear();
|
||||
source_fragment_imag.clear();
|
||||
|
||||
//
|
||||
// Iterator over warp-level accumulator fragment
|
||||
//
|
||||
|
||||
AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real);
|
||||
AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag);
|
||||
|
||||
//
|
||||
// Iterate over accumulator tile
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator_real.load(source_fragment_real);
      source_iterator_imag.load(source_fragment_imag);

      ++source_iterator_real;
      ++source_iterator_imag;

      //
      // Convert and store fragment
      //

      __syncthreads();

      typename AccumulatorFragmentIterator::Fragment accum_fragment_real;
      typename AccumulatorFragmentIterator::Fragment accum_fragment_imag;

      accum_fragment_iterator_real.load(accum_fragment_real);
      accum_fragment_iterator_imag.load(accum_fragment_imag);

      ++accum_fragment_iterator_real;
      ++accum_fragment_iterator_imag;

      this->warp_tile_iterator_.store(accum_fragment_real);
      this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK];
      typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment_real[0]);
      shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride);

      // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
      static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time");

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment_real;
      typename OutputTileIterator::Fragment output_fragment_imag;

      apply_output_operator_(
        output_fragment_real,
        output_fragment_imag,
        output_op,
        aligned_accum_fragment_real[0],
        aligned_accum_fragment_imag[0],
        source_fragment_real,
        source_fragment_imag);

      //
      // Store the final result
      //

      destination_iterator_real.store(output_fragment_real);
      destination_iterator_imag.store(output_fragment_imag);

      ++destination_iterator_real;
      ++destination_iterator_imag;
    }
  }

private:

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
    typename OutputTileIterator::Fragment &output_fragment_real,
    typename OutputTileIterator::Fragment &output_fragment_imag,
    OutputOp const &output_op,                   ///< Output operator
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real,
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag,
    typename OutputTileIterator::Fragment const &source_fragment_real,
    typename OutputTileIterator::Fragment const &source_fragment_imag) {

    OutputAccessType *output_frag_real_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment_real);

    OutputAccessType *output_frag_imag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment_imag);

    AccumulatorAccessType const *compute_frag_real_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_real);

    AccumulatorAccessType const *compute_frag_imag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_imag);

    OutputAccessType const *source_frag_real_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment_real);

    OutputAccessType const *source_frag_imag_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment_imag);

    int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      // Call the output operator
      auto result_fragment = output_op(
        make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]),
        make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i])
      );

      output_frag_real_ptr[i] = result_fragment.real;
      output_frag_imag_ptr[i] = result_fragment.imag;
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
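
The loop above pairs real and imaginary fragments so that a single output functor applies the complex-valued linear combination D = alpha * accumulator + beta * C to planar-complex data. A minimal host-side sketch of the arithmetic this expands to, using standard types rather than CUTLASS fragments (all names here are illustrative):

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Planar-complex update: real and imaginary parts live in disjoint arrays,
// and each complex multiply-add expands into four real multiply-adds.
void planar_complex_epilogue(
    std::complex<float> alpha, std::complex<float> beta,
    std::vector<float> const &accum_real, std::vector<float> const &accum_imag,
    std::vector<float> const &source_real, std::vector<float> const &source_imag,
    std::vector<float> &out_real, std::vector<float> &out_imag) {

  for (std::size_t i = 0; i < accum_real.size(); ++i) {
    // real(alpha * accum + beta * source)
    out_real[i] = alpha.real() * accum_real[i] - alpha.imag() * accum_imag[i] +
                  beta.real() * source_real[i] - beta.imag() * source_imag[i];

    // imag(alpha * accum + beta * source)
    out_imag[i] = alpha.real() * accum_imag[i] + alpha.imag() * accum_real[i] +
                  beta.real() * source_imag[i] + beta.imag() * source_real[i];
  }
}
```

Keeping the planes disjoint is what lets the real-valued iterators above be reused unchanged; only the pointer offset SharedStorage::kImaginaryStride distinguishes the imaginary plane.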
@ -85,9 +85,6 @@ class InterleavedEpilogue {
  using OutputTileIterator = OutputTileIterator_;
  using OutputOp = OutputOp_;

  /// Output layout is always row-major
  using Layout = layout::ColumnMajorInterleaved<InterleavedK>;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

@ -437,11 +437,10 @@ struct OutputTileOptimalThreadMap {
/// - minimal address arithmetic
/// - minimal predicate calculations
///
template <typename WarpCount_, typename MmaCount_, int Threads,
template <typename WarpCount_, typename Iterations_, int Threads,
          int ElementsPerAccess, int ElementSize>
struct InterleavedOutputTileThreadMap {
  using WarpCount = WarpCount_;
  using MmaCount = MmaCount_;

  static int const kWarpSize = 32;
  static int const kThreads = Threads;
@ -460,7 +459,7 @@ struct InterleavedOutputTileThreadMap {
  // Output
  //

  using Iterations = MmaCount;
  using Iterations = Iterations_;

  using Delta = layout::PitchLinearShape<kWarpSize * kElementsPerAccess, 1>;

@ -491,6 +490,67 @@ struct InterleavedOutputTileThreadMap {

////////////////////////////////////////////////////////////////////////////////

/// Template metaprogram for partitioning a 4D interleaved layout across warps
/// to achieve several performance objectives:
///
///   - coalesced memory accesses in units of 64 Byte lines
///   - minimal address arithmetic
///   - minimal predicate calculations
///
template <typename WarpCount_, typename Iterations_, int Threads,
          int ElementsPerAccess, int ElementSize>
struct InterleavedConvOutputTileThreadMap {
  using WarpCount = WarpCount_;

  static int const kWarpSize = 32;
  static int const kThreads = Threads;
  static int const kWarpCount = kThreads / kWarpSize;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  //
  // Metaprogram computation
  //

  struct Detail {};

  //
  // Output
  //

  using Iterations = Iterations_;

  using Delta = MatrixShape<kWarpSize / 4, 4 * kElementsPerAccess>;

  /// Initial offset function
  CUTLASS_HOST_DEVICE
  static MatrixCoord initial_offset(int thread_idx) {
    int warp_idx = thread_idx / kWarpSize;
    int lane_idx = thread_idx % kWarpSize;

    // Compute warp location
    MatrixCoord warp_footprint{
      Delta::kRow * Iterations::kRow,
      Delta::kColumn * Iterations::kColumn,
    };

    MatrixCoord warp_offset{warp_idx % WarpCount::kRow,
                            warp_idx / WarpCount::kRow};

    // Compute per-lane offset
    MatrixCoord thread_offset_in_warp{lane_idx / 4,
                                      (lane_idx % 4) * kElementsPerAccess};

    MatrixCoord thread_offset_in_threadblock_tile =
        warp_footprint * warp_offset + thread_offset_in_warp;

    return thread_offset_in_threadblock_tile;
  }
};
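
The initial_offset function above arranges the 32 lanes of a warp as 8 rows by 4 columns, with each column covering kElementsPerAccess contiguous elements. A standalone sketch of just the per-lane arithmetic (the warp-footprint term is omitted, and kElementsPerAccess = 4 is an illustrative value):

```cpp
#include <cstdio>

int main() {
  int const kWarpSize = 32;
  int const kElementsPerAccess = 4;

  for (int thread_idx : {0, 1, 4, 31, 32}) {
    int lane_idx = thread_idx % kWarpSize;

    int row = lane_idx / 4;                         // 8 rows of lanes per warp
    int col = (lane_idx % 4) * kElementsPerAccess;  // 4 lane columns per warp

    std::printf("thread %2d -> in-warp offset (row %d, col %2d)\n",
                thread_idx, row, col);
  }
  return 0;
}
```

Thread 1 lands one access to the right of thread 0, thread 4 starts the next row, and thread 32 (the first lane of the second warp) repeats the pattern shifted by the warp footprint computed above.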
////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

@ -36,9 +36,9 @@
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"

@ -107,16 +107,16 @@ public:
  // Data members
  //

  Index stride;                ///< stride in bytes between rows
  LongIndex stride;            ///< stride in bytes between rows

  Index increment_row;         ///< increment quantity (in bytes) to advance when moving between rows
  Index increment_group;       ///< increment quantity (in bytes) to advance when moving to the next group
  Index increment_cluster;     ///< increment quantity (in bytes) to advance when moving to the next cluster
  LongIndex increment_row;     ///< increment quantity (in bytes) to advance when moving between rows
  LongIndex increment_group;   ///< increment quantity (in bytes) to advance when moving to the next group
  LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster

  Index advance_row;           ///< amount to add to move to the next 'row' position
  Index advance_group;         ///< amount to add to move to the next 'group' position
  Index advance_cluster;       ///< amount to add to move to the next 'cluster' position
  Index advance_tile;          ///< amount to add to move to the next 'tile'
  LongIndex advance_row;       ///< amount to add to move to the next 'row' position
  LongIndex advance_group;     ///< amount to add to move to the next 'group' position
  LongIndex advance_cluster;   ///< amount to add to move to the next 'cluster' position
  LongIndex advance_tile;      ///< amount to add to move to the next 'tile'

  //
  // Methods
@ -125,7 +125,7 @@ public:
  CUTLASS_HOST_DEVICE
  Status initialize(Index stride_) {

    stride = stride_;
    stride = LongIndex(stride_);

    increment_row = stride * ThreadMap::Delta::kRow;

@ -261,8 +261,8 @@ public:

    // Initialize pointer
    byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
      thread_offset.row() * params_.stride +
      thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;
      LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
      LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;

    // Initialize internal state counter
    state_[0] = state_[1] = state_[2] = 0;
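
The Index to LongIndex widening in this hunk is not cosmetic: the pointer arithmetic multiplies a row coordinate by a byte stride, and that product overflows 32 bits well before the allocation itself becomes implausible. A small standalone illustration (the matrix shape is hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // A 65536 x 65536 float32 matrix: 2^18 bytes per row.
  std::uint32_t rows = 1u << 16;
  std::uint32_t stride_bytes = 1u << 18;

  // 32-bit arithmetic wraps: 2^16 * 2^18 = 2^34 is congruent to 0 mod 2^32.
  std::uint32_t wrapped = rows * stride_bytes;

  // Widening either operand first preserves the full 64-bit byte offset.
  std::int64_t correct = std::int64_t(rows) * std::int64_t(stride_bytes);

  std::printf("32-bit: %u, 64-bit: %lld\n", wrapped, (long long)correct);
  return 0;
}
```

This is exactly why the casts appear on thread_offset.row() and params_.stride in the pointer initialization above.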
@ -276,7 +276,7 @@ public:

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
  void load_with_byte_offset(Fragment &frag, int64_t byte_offset) {

    uint8_t *byte_pointer = byte_pointer_;
    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
@ -299,7 +299,7 @@ public:

        bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

        AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
        AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);

        CUTLASS_PRAGMA_UNROLL
        for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
@ -328,9 +328,15 @@ public:
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_byte_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
  void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) {
    uint8_t *byte_pointer = byte_pointer_;
    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

@ -352,7 +358,7 @@ public:

        bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

        AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
        AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);

        CUTLASS_PRAGMA_UNROLL
        for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
@ -382,6 +388,12 @@ public:
    }
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_byte_offset(frag, 0);
  }

  /// Advances to the next position to load or store
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator &operator++() {
@ -440,6 +452,7 @@ public:
};

////////////////////////////////////////////////////////////////////////////////
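
A pattern worth noting in the hunk above: the original load and store entry points become thin wrappers over the new *_with_byte_offset workers. One iterator instance can then address two tensors separated by a fixed byte distance, which is what the planar complex epilogue needs for its imaginary plane. A stripped-down sketch of the pattern (the struct and element type are illustrative):

```cpp
#include <cstdint>

// Zero-offset entry points delegate to offset-taking workers, so a single
// iterator can serve both planes of a planar-complex tensor.
struct PlaneIterator {
  float const *base;

  void load_with_byte_offset(float &value, std::int64_t byte_offset) const {
    value = *reinterpret_cast<float const *>(
        reinterpret_cast<char const *>(base) + byte_offset);
  }

  void load(float &value) const {
    load_with_byte_offset(value, 0);  // the real plane sits at offset zero
  }
};
```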
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator
@ -447,7 +460,7 @@ public:
template <
  typename ThreadMap_,      ///< Thread map (concept: OutputTileThreadMap)
  typename Element_,        ///< Element data type
  int InterleavedK          ///< Number of Interleaved K
  int InterleavedN          ///< Number of Interleaved N
>
class InterleavedPredicatedTileIterator {
public:
@ -455,7 +468,7 @@ public:

  using Element = Element_;

  using Layout = layout::ColumnMajorInterleaved<InterleavedK>;
  using Layout = layout::ColumnMajorInterleaved<InterleavedN>;
  using TensorRef = TensorRef<Element, Layout>;
  using ConstTensorRef = typename TensorRef::ConstTensorRef;

@ -483,10 +496,10 @@ public:
  // Data members
  //

  Index stride;             ///< stride in bytes between columns
  LongIndex stride;         ///< stride in bytes between columns

  Index advance_row;        ///< amount to add to move to the next 'row' position
  Index advance_column;     ///< amount to add to move to the next 'column' position
  LongIndex advance_row;    ///< amount to add to move to the next 'row' position
  LongIndex advance_column; ///< amount to add to move to the next 'column' position

  //
  // Methods
@ -494,14 +507,16 @@ public:

  CUTLASS_HOST_DEVICE
  Status initialize(Index stride_) {
    stride = stride_;

    stride = LongIndex(stride_);

    advance_row =
      ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8;

    advance_column =
      stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *
      sizeof_bits<Element>::value * ThreadMap::kWarpSize / 8;
    advance_column = LongIndex(stride_) - ThreadMap::Iterations::kContiguous *
                                          kElementsPerAccess *
                                          sizeof_bits<Element>::value *
                                          ThreadMap::kWarpSize / 8;

    return Status::kSuccess;
  }
@ -602,10 +617,10 @@ public:
  ):
    params_(params) {
    TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +
      TensorCoord(threadblock_offset.contiguous() * InterleavedK,
                  threadblock_offset.strided() / InterleavedK);
      TensorCoord(threadblock_offset.contiguous() * InterleavedN,
                  threadblock_offset.strided() / InterleavedN);

    extent_col_ = extent.strided() / InterleavedK;
    extent_col_ = extent.strided() / InterleavedN;
    thread_start_col_ = thread_offset.strided();

    // Initialize predicates
@ -613,13 +628,13 @@ public:
    for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
      mask_.predicates[c] =
        ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) <
         (extent.contiguous() * InterleavedK));
         (extent.contiguous() * InterleavedN));
    }

    // Initialize pointer
    byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
      thread_offset.strided() * params_.stride +
      thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess;
      LongIndex(thread_offset.strided()) * LongIndex(params_.stride) +
      LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess;

    // Initialize internal state counter
    iteration_contiguous_ = iteration_strided_ = 0;
@ -634,6 +649,7 @@ public:
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {

    uint8_t *byte_pointer = byte_pointer_;
    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
    AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);

@ -100,12 +100,28 @@ public:
  /// Padding quantity
  using Padding = MatrixShape<
    0,
    4 * Policy::kElementsPerAccess>;
    4 * Policy::kElementsPerAccess
#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES
    + 1
#endif
  >;

private:

#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES
  /// Storage type for accessing memory
  using AccessType = AlignedArray<Element, Policy::kElementsPerAccess>;
  using AccessType = AlignedArray<
    Element,
    1
  >;

#else
  /// Storage type for accessing memory
  using AccessType = AlignedArray<
    Element,
    Policy::kElementsPerAccess
  >;
#endif

  //
  // Data members
@ -130,18 +146,21 @@ public:
    unsigned lane_id
  ):
    pointer_(reinterpret_cast<AccessType *>(ref.data())),
    layout_(ref.stride()[0] / Policy::kElementsPerAccess) {
    layout_(ref.stride()[0] / AccessType::kElements) {

    auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout();
    MatrixCoord lane_offset = lane_layout.inverse(lane_id);

    pointer_ += layout_(lane_offset);
    pointer_ += layout_({
      lane_offset.row(),
      lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements)
    });
  }

  /// Adds a pointer offset
  CUTLASS_HOST_DEVICE
  TileIteratorSimt & add_pointer_offset(Index pointer_offset) {
    pointer_ += pointer_offset / Policy::kElementsPerAccess;
    pointer_ += pointer_offset / AccessType::kElements;
    return *this;
  }

@ -151,7 +170,7 @@ public:

    pointer_ += layout_({
      tile_offset.row() * Shape::kRow,
      (tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess)
      (tile_offset.column() * Shape::kColumn / int(AccessType::kElements))
    });

    return *this;
@ -173,7 +192,7 @@ public:
    // de-vectorized stores
    using ScalarAccessType = AlignedArray<Element, 1>;
    ScalarAccessType const *scalarFragPtr = reinterpret_cast<ScalarAccessType const *>(&frag);
    ScalarAccessType *scalarPointer = reinterpret_cast<ScalarAccessType *>(pointer_);
    ScalarAccessType *scalarPointer = reinterpret_cast<ScalarAccessType *>(pointer_) + pointer_offset;

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
@ -187,7 +206,7 @@ public:
    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
      pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn] = frag_ptr[n];
      pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n];
    }
#endif
  }
@ -206,7 +225,7 @@ public:

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
      frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn];
      frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)];
    }
  }
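
The `+ 1` added to the Padding shape under CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES is the classic shared-memory trick: padding each row by one element staggers addresses across banks, so column-wise scalar accesses stop colliding. A toy model of the effect (the bank count matches current NVIDIA GPUs; the row width is illustrative):

```cpp
// With 32 banks of 4-byte words, a fixed column of an unpadded 32-wide tile
// maps every row to the same bank; one element of per-row padding rotates
// the mapping by one bank per row.
int const kBanks = 32;
int const kRowWidth = 32;

int bank_unpadded(int row, int col) {
  return (row * kRowWidth + col) % kBanks;        // independent of row
}

int bank_padded(int row, int col) {
  return (row * (kRowWidth + 1) + col) % kBanks;  // shifts by one each row
}
```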
@ -187,7 +187,8 @@ public:
      int access = access_idx % 2;

      int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess +
        access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + access;
        access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess +
        access + pointer_offset / Policy::kElementsPerAccess;

      int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx;

@ -219,7 +220,9 @@ public:
      int access_quad = access_idx / 2;
      int access = access_idx % 2;

      int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + access;
      int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta +
        access + pointer_offset / Policy::kElementsPerAccess;

      int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx;

      frag_ptr[frag_idx] = pointer_[ptr_offset];
@ -382,7 +385,7 @@ public:

      int ptr_row_offset = row_idx * 2;

      int ptr_offset = layout_({ptr_row_offset, ptr_column_offset});
      int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess;

      pointer_[ptr_offset] = frag_ptr[frag_idx];
    }

@ -25,7 +25,12 @@

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif

#include "cutlass/cutlass.h"

/**
@ -201,6 +206,18 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul,
  rem = src - (quo * div);
}

/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
int round_up(int a, int b) {
  return ((a + b - 1) / b) * b;
}

/// Returns the ceiling of (a / b)
CUTLASS_HOST_DEVICE
int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}
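
These two helpers are the standard integer tiling utilities: covering a 10-element extent with tiles of 4 takes ceil_div(10, 4) = 3 tiles spanning round_up(10, 4) = 12 elements. A self-contained restatement with the worked example checked at compile time:

```cpp
// Same definitions as above, restated as constexpr so the example values
// can be verified by static_assert.
constexpr int round_up(int a, int b) { return ((a + b - 1) / b) * b; }
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

static_assert(ceil_div(10, 4) == 3, "three tiles of four cover ten elements");
static_assert(round_up(10, 4) == 12, "the padded extent is twelve elements");
```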
/******************************************************************************
 * Min/Max
 ******************************************************************************/

@ -1,4 +1,4 @@
/***************************************************************************************************
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
@ -1207,4 +1207,8 @@ struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {

/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -192,7 +192,8 @@ template <
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
  typename ThreadblockSwizzle_ =
      typename threadblock::GemmCohortThreadblockSwizzle<LayoutA_, LayoutB_>,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
@ -241,6 +242,8 @@ class Gemm {
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static bool const kIsBetaZero = IsBetaZero;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Define the kernel
  using GemmKernel = typename kernel::DefaultGemm<
@ -348,21 +351,23 @@ public:

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    if (kSplitKSerial && args.split_k_slices > 1) {

      // Determine grid shape
      ThreadblockSwizzle threadblock_swizzle;

      cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.split_k_slices);

      return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return 0;
    return bytes;
  }
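
The rewritten get_workspace_size only reserves memory when serial split-K is actually in play: one int semaphore per output tile. For a 1024 x 1024 problem with 128 x 128 threadblock tiles, that is an 8 x 8 grid, so 64 semaphores, or 256 bytes. A standalone restatement of the sizing rule (shapes are hypothetical):

```cpp
#include <cstddef>

// One int semaphore per output tile when serial split-K is enabled with
// more than one slice; otherwise no workspace at all.
std::size_t split_k_workspace_bytes(int m, int n, int tile_m, int tile_n,
                                    int split_k_slices, bool split_k_serial) {
  if (!(split_k_serial && split_k_slices > 1)) {
    return 0;
  }

  std::size_t tiles_m = std::size_t((m + tile_m - 1) / tile_m);
  std::size_t tiles_n = std::size_t((n + tile_n - 1) / tile_n);

  return sizeof(int) * tiles_m * tiles_n;  // 4 * 8 * 8 = 256 bytes for the example
}
```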
  /// Initializes GEMM state from arguments.
@ -426,6 +431,7 @@ public:
    params_.ref_B.reset(args.ref_B.non_const_ref().data());
    params_.ref_C.reset(args.ref_C.non_const_ref().data());
    params_.ref_D.reset(args.ref_D.data());
    params_.output_op = args.epilogue;
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
@ -560,6 +566,8 @@ class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;
  static bool const kSplitKSerial = SplitKSerial;
  static bool const kIsBetaZero = IsBetaZero;
745
include/cutlass/gemm/device/gemm_array.h
Normal file
@ -0,0 +1,745 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a pipelined batched array GEMM kernel. Does not support split-K.
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_array.h"

#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

////////////////////////////////////////////////////////////////////////////////

/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may
  be invoked from host code.

  The contributions of this class are:

    1. At compile time, it maps data types and high-level structural parameters onto
       specific CUTLASS components.

    2. At runtime, it maps logical arguments to GEMM problems to kernel parameters.

    3. At runtime, it launches kernels on the device.

  The intent is to provide a convenient mechanism for interacting with most plausible GEMM
  configurations for each supported architecture. Consequently, not all parameters are exposed
  to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy
  are selected to tradeoff simplicity of the interface with flexibility. We expect
  most configurations to be specified at this level. Applications with more exotic requirements
  may construct their kernels of interest using CUTLASS components at the threadblock, warp,
  and thread levels of abstraction.

  CUTLASS exposes computations using the functor design pattern in which objects compose some
  internal state with an overloaded function call operator. This enables decoupling of
  initialization from execution, possibly reducing overhead during steady state phases of
  application execution.

  CUTLASS device-level operators expose an Arguments structure encompassing each logical
  input to the computation. This is distinct from the kernel-level Params structure pattern
  which contains application-specific precomputed state needed by the device code.

  Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN
  is as follows:

    //
    // Instantiate the CUTLASS GEMM operator.
    //

    cutlass::gemm::device::Gemm<
      float,
      cutlass::layout::ColumnMajor,
      float,
      cutlass::layout::ColumnMajor,
      float,
      cutlass::layout::ColumnMajor
    > gemm_op;

    //
    // Launch the GEMM operation on the device
    //

    cutlass::Status status = gemm_op({
      {m, n, k},      // GemmCoord problem_size,
      {A, lda},       // TensorRef<float, layout::ColumnMajor> ref_A,
      {B, ldb},       // TensorRef<float, layout::ColumnMajor> ref_B,
      {C, ldc},       // TensorRef<float, layout::ColumnMajor> ref_C,
      {D, ldd},       // TensorRef<float, layout::ColumnMajor> ref_D,
      {alpha, beta}   // EpilogueOutputOp::Params epilogue_op_params
    });


  A simplified view of the template is listed below.

    template <
      /// Element type for A matrix operand
      typename ElementA,

      /// Layout type for A matrix operand
      typename LayoutA,

      /// Element type for B matrix operand
      typename ElementB,

      /// Layout type for B matrix operand
      typename LayoutB,

      /// Element type for C and D matrix operands
      typename ElementC,

      /// Layout type for C and D matrix operands
      typename LayoutC,

      /// Element type for internal accumulation
      typename ElementAccumulator,

      /// Operator class tag
      typename OperatorClass,

      /// Tag indicating architecture to tune for
      typename ArchTag,

      /// Threadblock-level tile size (concept: GemmShape)
      typename ThreadblockShape,

      /// Warp-level tile size (concept: GemmShape)
      typename WarpShape,

      /// Instruction-level tile size (concept: GemmShape)
      typename InstructionShape,

      /// Epilogue output operator
      typename EpilogueOutputOp,

      /// Threadblock-level swizzling operator
      typename ThreadblockSwizzle,

      /// Number of stages used in the pipelined mainloop
      int Stages
    >
    class Gemm;
*/
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentB,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::Operator
>
class GemmArray {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  using Operator = Operator_;

  /// Define the kernel
  using DefaultGemmKernel = typename kernel::DefaultGemm<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    kStages,
    false,
    Operator,
    false
  >::GemmKernel;

  using GemmKernel = kernel::GemmArray<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle>;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;

    ElementA const * const *ptr_A;
    LayoutA layout_A;

    ElementB const * const *ptr_B;
    LayoutB layout_B;

    ElementC const * const *ptr_C;
    LayoutC layout_C;

    ElementC * const * ptr_D;
    LayoutC layout_D;

    typename EpilogueOutputOp::Params epilogue;
    int batch_count;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      ElementA const * const *ptr_A_,
      LayoutA layout_A_,
      ElementB const * const *ptr_B_,
      LayoutB layout_B_,
      ElementC const * const *ptr_C_,
      LayoutC layout_C_,
      ElementC * const * ptr_D_,
      LayoutC layout_D_,
      typename EpilogueOutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      ptr_A(ptr_A_),
      layout_A(layout_A_),
      ptr_B(ptr_B_),
      layout_B(layout_B_),
      ptr_C(ptr_C_),
      layout_C(layout_C_),
      ptr_D(ptr_D_),
      layout_D(layout_D_),
      epilogue(epilogue_),
      batch_count(batch_count_) { }
  };

private:

  /// Kernel parameters object
  typename GemmKernel::Params params_;

public:

  /// Constructs the GEMM.
  GemmArray() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    if (args.layout_A.stride(0) % kAlignmentA) {
      return Status::kErrorMisalignedOperand;
    }

    if (args.layout_B.stride(0) % kAlignmentB) {
      return Status::kErrorMisalignedOperand;
    }

    if (args.layout_C.stride(0) % kAlignmentC) {
      return Status::kErrorMisalignedOperand;
    }

    if (args.layout_D.stride(0) % kAlignmentC) {
      return Status::kErrorMisalignedOperand;
    }

    if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
        (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
        (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return 0;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      args.batch_count,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});

    // Initialize the Params structure
    params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ptr_A,
      args.layout_A,
      args.ptr_B,
      args.layout_B,
      args.ptr_C,
      args.layout_C,
      args.ptr_D,
      args.layout_D,
      args.epilogue,
      args.batch_count
    };

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      args.batch_count,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});

    params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ptr_A,
      args.layout_A,
      args.ptr_B,
      args.layout_B,
      args.ptr_C,
      args.layout_C,
      args.ptr_D,
      args.layout_D,
      args.epilogue,
      args.batch_count
    };

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }

      result = cudaFuncSetAttribute(
          Kernel<GemmKernel>,
          cudaFuncAttributePreferredSharedMemoryCarveout, 100);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }
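
The run() method above shows the required opt-in for kernels whose SharedStorage exceeds the default 48 KB dynamic shared memory limit: the limit must be raised per kernel with cudaFuncSetAttribute before launch. A minimal standalone CUDA illustration of the same sequence (the kernel and sizes are illustrative; 64 KB requires an architecture that supports the larger carveout, such as Volta):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

__global__ void kernel_with_dynamic_smem(float *out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = float(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

int main() {
  int smem_size = 64 << 10;  // 64 KB, above the default 48 KB limit

  // Without this attribute, launching with 64 KB of dynamic shared memory
  // fails with cudaErrorInvalidValue.
  cudaError_t result = cudaFuncSetAttribute(
      kernel_with_dynamic_smem,
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      smem_size);

  if (result != cudaSuccess) {
    std::printf("opt-in failed: %s\n", cudaGetErrorString(result));
    return 1;
  }

  float *out = nullptr;
  cudaMalloc(&out, 256 * sizeof(float));
  kernel_with_dynamic_smem<<<1, 256, smem_size>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
```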
  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////
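
Unlike the single-pointer Gemm example in the header comment, GemmArray consumes arrays of pointers, one entry per batch. A hedged usage sketch against the Arguments structure defined above (the pointer arrays are assumed device-resident and sized to batch_count, and the integer leading dimensions convert to layouts via the layout constructors):

```cpp
#include "cutlass/gemm/device/gemm_array.h"

cutlass::Status run_batched_array_sgemm(
    int m, int n, int k, int batch_count,
    float const * const *ptr_A, int lda,
    float const * const *ptr_B, int ldb,
    float const * const *ptr_C, int ldc,
    float * const *ptr_D, int ldd,
    float alpha, float beta) {

  using GemmArrayOp = cutlass::gemm::device::GemmArray<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;

  GemmArrayOp gemm_op;

  // One problem size and epilogue shared by every batch entry; each entry
  // supplies its own A/B/C/D pointer through the pointer arrays.
  return gemm_op({
    {m, n, k},
    ptr_A, lda,
    ptr_B, ldb,
    ptr_C, ldc,
    ptr_D, ldd,
    {alpha, beta},
    batch_count
  });
}
```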
/// Partial specialization for column-major output exchanges problem size and operand.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB,
  typename Operator_
>
class GemmArray<
  ElementA_,
  LayoutA_,
  ElementB_,
  LayoutB_,
  ElementC_,
  layout::ColumnMajor,
  ElementAccumulator_,
  OperatorClass_,
  ArchTag_,
  ThreadblockShape_,
  WarpShape_,
  InstructionShape_,
  EpilogueOutputOp_,
  ThreadblockSwizzle_,
  Stages,
  AlignmentA,
  AlignmentB,
  Operator_
> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static int const kStages = Stages;

  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = false;

  //
  using UnderlyingOperator = GemmArray<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA
  >;

  using UnderlyingArguments = typename UnderlyingOperator::Arguments;
  using GemmKernel = typename UnderlyingOperator::GemmKernel;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;

    ElementA const * const *ptr_A;
    LayoutA layout_A;

    ElementB const * const *ptr_B;
    LayoutB layout_B;

    ElementC const * const *ptr_C;
    LayoutC layout_C;

    ElementC * const * ptr_D;
    LayoutC layout_D;

    typename EpilogueOutputOp::Params epilogue;
    int batch_count;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      ElementA const * const *ptr_A_,
      LayoutA layout_A_,
      ElementB const * const *ptr_B_,
      LayoutB layout_B_,
      ElementC const * const *ptr_C_,
      LayoutC layout_C_,
      ElementC * const * ptr_D_,
      LayoutC layout_D_,
      typename EpilogueOutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      ptr_A(ptr_A_),
      layout_A(layout_A_),
      ptr_B(ptr_B_),
      layout_B(layout_B_),
      ptr_C(ptr_C_),
      layout_C(layout_C_),
      ptr_D(ptr_D_),
      layout_D(layout_D_),
      epilogue(epilogue_),
      batch_count(batch_count_) { }
  };

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmArray() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static UnderlyingArguments to_underlying_arguments(Arguments const &args) {

    GemmCoord problem_size{
      args.problem_size.n(),
      args.problem_size.m(),
      args.problem_size.k()
    };

    return UnderlyingArguments(
      problem_size,
      args.ptr_B,
      args.layout_B.stride(),
      args.ptr_A,
      args.layout_A.stride(),
      args.ptr_C,
      args.layout_C.stride(),
      args.ptr_D,
      args.layout_D.stride(),
      args.epilogue,
      args.batch_count
    );
  }
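
This specialization relies on the identity (A B)^T = B^T A^T: a column-major D occupies the same bytes as a row-major D^T, so the operator swaps A with B, transposes the layouts, and exchanges m and n, exactly as to_underlying_arguments above does. A tiny self-check of the identity on 2 x 2 matrices:

```cpp
#include <cassert>

int main() {
  // Column-major storage: element (r, c) lives at index r + 2 * c.
  float A[4] = {1, 2, 3, 4};  // A = [[1, 3], [2, 4]]
  float B[4] = {5, 6, 7, 8};  // B = [[5, 7], [6, 8]]

  // D = A * B, written column-major.
  float D[4];
  for (int c = 0; c < 2; ++c)
    for (int r = 0; r < 2; ++r)
      D[r + 2 * c] = A[r] * B[2 * c] + A[r + 2] * B[1 + 2 * c];

  // Dt = B^T * A^T, written row-major: element (r, c) at index 2 * r + c.
  // B[2r + k] reads B^T(r, k) and A[2k + c] reads A^T(k, c) straight out of
  // the column-major buffers, with no data movement.
  float Dt[4];
  for (int r = 0; r < 2; ++r)
    for (int c = 0; c < 2; ++c)
      Dt[2 * r + c] = B[2 * r] * A[c] + B[2 * r + 1] * A[2 + c];

  // Identical bytes: column-major D is row-major D^T.
  for (int i = 0; i < 4; ++i) assert(D[i] == Dt[i]);
  return 0;
}
```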
  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@ -370,8 +370,8 @@ public:

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      args.batch_count,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.batch_count);
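
Note that the hunk above moves batch_count to the final parameter of get_tiled_shape. The computation itself is plain ceiling division in m and n plus one grid slice per batch (simplified here; the real batched swizzle also caps the z extent of the grid). A restated sketch with a worked example:

```cpp
// A 1024 x 512 problem with 128 x 128 threadblock tiles and 3 batches
// yields an 8 x 4 x 3 grid of tiles.
struct Coord3 { int m, n, k; };

Coord3 get_tiled_shape_sketch(Coord3 problem, Coord3 tb, int batch_count) {
  return {
    (problem.m + tb.m - 1) / tb.m,  // 1024 / 128 = 8
    (problem.n + tb.n - 1) / tb.n,  //  512 / 128 = 4
    batch_count                     // one slice per batch entry
  };
}
```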
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params{
|
||||
|
||||
@ -201,6 +201,8 @@ template <
|
||||
ComplexTransform TransformA = ComplexTransform::kNone,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB = ComplexTransform::kNone,
|
||||
/// Multiply-add operator
|
||||
typename Operator_ = arch::OpMultiplyAddComplex,
|
||||
/// If true, kernel supports split-K with serial reduction
|
||||
bool SplitKSerial = false
|
||||
>
|
||||
@ -228,7 +230,11 @@ class GemmComplex {
|
||||
static int const kStages = Stages;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = TransformB;
|
||||
using Operator = Operator_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static int const kAlignmentA = 1;
|
||||
static int const kAlignmentB = 1;
|
||||
static int const kAlignmentC = EpilogueOutputOp::kCount;
|
||||
|
||||
/// Define the kernel
|
||||
using GemmKernel = typename kernel::DefaultGemmComplex<
|
||||
@ -249,6 +255,7 @@ class GemmComplex {
|
||||
kStages,
|
||||
kTransformA,
|
||||
kTransformB,
|
||||
Operator,
|
||||
kSplitKSerial
|
||||
>::GemmKernel;
|
||||
|
||||
@ -498,6 +505,8 @@ template <
|
||||
ComplexTransform TransformA,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB,
|
||||
/// Multiply-add operator
|
||||
typename Operator_,
|
||||
/// If true, kernel supports split-K as a serial reduction
|
||||
bool SplitKSerial
|
||||
>
|
||||
@ -519,6 +528,7 @@ class GemmComplex<
|
||||
Stages,
|
||||
TransformA,
|
||||
TransformB,
|
||||
Operator_,
|
||||
SplitKSerial
|
||||
> {
|
||||
public:
|
||||
@ -542,6 +552,7 @@ public:
|
||||
using EpilogueOutputOp = EpilogueOutputOp_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static int const kStages = Stages;
|
||||
using Operator = Operator_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using UnderlyingOperator = GemmComplex<
|
||||
@ -562,8 +573,15 @@ public:
|
||||
Stages,
|
||||
TransformA,
|
||||
TransformB,
|
||||
Operator,
|
||||
SplitKSerial
|
||||
>;
|
||||
|
||||
static int const kAlignmentA = UnderlyingOperator::kAlignmentB;
|
||||
static int const kAlignmentB = UnderlyingOperator::kAlignmentA;
|
||||
static int const kAlignmentC = UnderlyingOperator::kAlignmentC;
|
||||
static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB;
|
||||
static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA;
|
||||
|
||||
using UnderlyingArguments = typename UnderlyingOperator::Arguments;
|
||||
using GemmKernel = typename UnderlyingOperator::GemmKernel;
|
||||
|
||||
372
include/cutlass/gemm/device/gemm_universal.h
Normal file
372
include/cutlass/gemm/device/gemm_universal.h
Normal file
@ -0,0 +1,372 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_base.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!
|
||||
The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
|
||||
batched array variants.
|
||||
*/
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentB,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::Operator,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversal :
  public GemmUniversalBase<
    typename kernel::DefaultGemmUniversal<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_
    >::GemmKernel
  > {
public:

  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using Base = GemmUniversalBase<
    typename kernel::DefaultGemmUniversal<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_
    >::GemmKernel
  >;

  using Arguments = typename Base::Arguments;
  using GemmKernel = typename Base::GemmKernel;
};
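
// A minimal usage sketch (illustrative only). The Arguments constructor signature is assumed
// from typical CUTLASS 2.x usage, and names such as ptr_A, lda, and workspace_ptr are
// hypothetical placeholders rather than symbols defined in this header.
//
//   using Gemm = cutlass::gemm::device::GemmUniversal<
//     float, cutlass::layout::ColumnMajor,    // ElementA, LayoutA
//     float, cutlass::layout::ColumnMajor,    // ElementB, LayoutB
//     float, cutlass::layout::ColumnMajor>;   // ElementC/D, LayoutC
//
//   Gemm gemm_op;
//
//   Gemm::Arguments args(
//     cutlass::gemm::GemmUniversalMode::kGemm,
//     {M, N, K},                              // problem size
//     1,                                      // batch count / split-K slices
//     {alpha, beta},                          // epilogue parameters
//     ptr_A, ptr_B, ptr_C, ptr_D,
//     0, 0, 0, 0,                             // batch strides
//     lda, ldb, ldc, ldd);
//
//   cutlass::Status status = gemm_op(args, workspace_ptr, stream);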

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output exchanges problem size and operands.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB,
  /// Operation performed by GEMM
  typename Operator_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB>
class GemmUniversal<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                    layout::ColumnMajor,  // partially specialized on LayoutC
                    ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
                    WarpShape_, InstructionShape_, EpilogueOutputOp_,
                    ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
                    Operator_, TransformA, TransformB> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  /// The underlying operator solves the transposed problem: A and B exchange roles, with
  /// layouts, alignments, and complex transforms exchanged to match.
  using UnderlyingOperator = typename GemmUniversal<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA,
    Operator,
    kTransformB,
    kTransformA
  >::Base;

  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  static int const kAlignmentC = EpilogueOutputOp::kCount;

  /// Argument structure
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversal() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    return args.transposed_problem();
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state from arguments, then runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};
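
// The identity exploited above: a column-major D = A * B is equivalent to the row-major
// problem D^T = B^T * A^T, which is why the operands, layouts, alignments, and complex
// transforms of A and B are exchanged when forming UnderlyingOperator.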

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
172
include/cutlass/gemm/device/gemm_universal_adapter.h
Normal file
@ -0,0 +1,172 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided,
    and batched array variants.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_base.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmKernel_>
class GemmUniversalAdapter {
public:

  using GemmKernel = GemmKernel_;

  static_assert(std::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value,
    "Universal adapter expects the kernel to be row-major and transposes its arguments.");

  using ThreadblockShape = typename GemmKernel::Mma::Shape;
  using WarpShape = typename GemmKernel::WarpShape;
  using InstructionShape = typename GemmKernel::InstructionShape;

  using OperatorClass = typename GemmKernel::OperatorClass;
  using ArchTag = typename GemmKernel::ArchTag;

  // Type, layout, and complex transform deliberately exchanged with B
  using ElementA = typename GemmKernel::ElementB;
  using LayoutA = typename layout::LayoutTranspose<typename GemmKernel::LayoutB>::type;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = GemmKernel::kTransformB;
  static int const kAlignmentA = GemmKernel::kAlignmentA;

  // Type, layout, and complex transform deliberately exchanged with A
  using ElementB = typename GemmKernel::ElementA;
  using LayoutB = typename layout::LayoutTranspose<typename GemmKernel::LayoutA>::type;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = GemmKernel::kTransformA;
  static int const kAlignmentB = GemmKernel::kAlignmentB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = cutlass::layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  static int const kAlignmentC = GemmKernel::kAlignmentC;

  using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

  static int const kStages = GemmKernel::Mma::kStages;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
  using Operator = typename GemmKernel::Operator;

  using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversalAdapter() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    return args.transposed_problem();
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state from arguments, then runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};
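
// A minimal sketch of wrapping an externally defined kernel (illustrative only; MyGemmKernel
// stands for any type produced by kernel::DefaultGemmUniversal with a row-major LayoutC):
//
//   using Operator = cutlass::gemm::device::GemmUniversalAdapter<MyGemmKernel>;
//
//   Operator op;
//   cutlass::Status status = op(args, workspace_ptr, stream);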

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
339
include/cutlass/gemm/device/gemm_universal_base.h
Normal file
@ -0,0 +1,339 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided,
    and batched array variants.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"

#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmKernel_>
class GemmUniversalBase {
public:

  using GemmKernel = GemmKernel_;
  using ThreadblockShape = typename GemmKernel::Mma::Shape;

  using ElementA = typename GemmKernel::ElementA;
  using LayoutA = typename GemmKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = GemmKernel::kTransformA;

  using ElementB = typename GemmKernel::ElementB;
  using LayoutB = typename GemmKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = GemmKernel::kTransformB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename GemmKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
  using Operator = typename GemmKernel::Operator;

  /// Argument structure
  using Arguments = typename GemmKernel::Arguments;

protected:

  /// Kernel parameters object
  typename GemmKernel::Params params_;

protected:

  /// Helper to obtain the grid dimensions with fix-up for split-K
  static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.batch_count);

    gemm_k_size = args.problem_size.k();

    if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {

      // Partition the K dimension into batch_count slices, each rounded up to the K-extent
      // implied by 128-bit accesses on the A and B operands.
      int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

      gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

      if (gemm_k_size) {
        grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
      }
    }
  }
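
  // Example: for half-precision A and B, sizeof_bits is 16, so kAlignK = max(128/16, 128/16) = 8.
  // With problem_size.k() = 4096 and batch_count = 3 split-K slices, ceil_div(4096, 3) = 1366,
  // which rounds up to gemm_k_size = 1368, giving grid_tiled_shape.k() = ceil_div(4096, 1368) = 3.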

public:

  /// Constructs the GEMM.
  GemmUniversalBase() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    return GemmKernel::can_implement(args);
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t workspace_bytes = 0;

    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {

      // Split-K parallel always requires a temporary workspace
      workspace_bytes =
        sizeof(ElementC) *
        size_t(args.batch_stride_D) *
        size_t(grid_tiled_shape.k());
    }
    else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {

      // Serial split-K only requires a temporary workspace if the number of partitions along the
      // GEMM K dimension is greater than one.
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    }

    return workspace_bytes;
  }
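
  // In the parallel split-K mode, the workspace holds one partial product matrix per K-partition,
  // reduced by a separate kernel; in the serial mode, it holds one semaphore counter per output
  // tile that orders the partial accumulations.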

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    return threadblock_swizzle.get_grid_shape(grid_tiled_shape);
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {

    int max_active_blocks = -1;
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size <= (48 << 10)) {

      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        Kernel<GemmKernel>,
        GemmKernel::kThreadCount,
        smem_size);

      if (result == cudaSuccess) {
        return max_active_blocks;
      }
    }
    else {

      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        Kernel<GemmKernel>,
        GemmKernel::kThreadCount,
        0);

      if (result != cudaSuccess) {
        return -1;
      }

      if (smem_capacity < 0) {
        int device_idx = 0;
        result = cudaGetDevice(&device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        cudaDeviceProp properties;
        result = cudaGetDeviceProperties(&properties, device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
      }

      return std::min(max_active_blocks, smem_capacity / smem_size);
    }

    return -1;
  }
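
  // Example: a kernel whose SharedStorage is 64 KB exceeds the 48 KB static limit, so occupancy
  // is queried with zero dynamic shared memory and then capped at smem_capacity / smem_size;
  // with 96 KB per SM this yields at most one active block per multiprocessor.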

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes) {

      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }

      // Serial split-K semaphores must start cleared.
      if (args.mode == GemmUniversalMode::kGemm) {
        cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    }

    // Get CUDA grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    // Initialize the Params structure
    params_ = typename GemmKernel::Params(
      args,
      grid_tiled_shape,
      gemm_k_size,
      static_cast<int *>(workspace)
    );

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes && !workspace) {
      return Status::kErrorWorkspaceNull;
    }

    params_.update(args, workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    // Kernels requiring more than 48 KB of shared storage must opt in to the larger
    // dynamic shared memory limit before launch.
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }

      result = cudaFuncSetAttribute(
          Kernel<GemmKernel>,
          cudaFuncAttributePreferredSharedMemoryCarveout, 100);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state from arguments, then runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -394,6 +394,15 @@ struct BatchedGemmCoord : public Coord<4, int> {
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Modes of operation for the universal GEMM
enum class GemmUniversalMode {
  kGemm,                // conventional GEMM; serial split-K reduction when the grid K extent > 1
  kGemmSplitKParallel,  // split-K with parallel reduction through a temporary workspace
  kBatched,             // batched strided GEMM
  kArray                // batched GEMM over arrays of operand pointers
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
118
include/cutlass/gemm/kernel/default_gemm_complex.h
Normal file
@ -0,0 +1,118 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
      the appropriate threadblock-scoped epilogue.

      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
      specializations here choose 'device::GemmTransposed' to implement this functionality.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"

#include "cutlass/transform/threadblock/predicated_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

////////////////////////////////////////////////////////////////////////////////

template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Multiply-add operator
  typename Operator,
  /// If true, kernel is configured to support serial reduction in the epilogue
  bool SplitKSerial
>
struct DefaultGemmComplex;

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////

@ -0,0 +1,229 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
      the appropriate threadblock-scoped epilogue.

      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
      specializations here choose 'device::GemmTransposed' to implement this functionality.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/kernel/gemm_planar_complex.h"
#include "cutlass/gemm/kernel/gemm_planar_complex_array.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"

#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h"
#include "cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Element type for A matrix operand
  typename ElementA,
  /// Layout type for A matrix operand
  typename LayoutA,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB,
  /// Layout type for B matrix operand
  typename LayoutB,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC,
  /// Layout type for C and D matrix operands
  typename LayoutC,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Math operation performed by GEMM (e.g. arch::OpMultiplyAdd)
  typename Operator,
  /// Conditional enabling to switch between stage counts
  typename Enable = void
>
struct DefaultGemmPlanarComplexUniversal;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for pipelined mainloop
template <
  /// Element type for A matrix operand
  typename ElementA,
  /// Layout type for A matrix operand
  typename LayoutA,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB,
  /// Layout type for B matrix operand
  typename LayoutB,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC,
  /// Layout type for C and D matrix operands
  typename LayoutC,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator
>
struct DefaultGemmPlanarComplexUniversal<
  ElementA,
  LayoutA,
  TransformA,
  kAlignmentA,
  ElementB,
  LayoutB,
  TransformB,
  kAlignmentB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  OperatorClass,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  Operator,
  typename std::enable_if<(Stages <= 2)>::type  // pipelined mainloop handles at most two stages
> {

  /// Define planar complex valued variants instead
  using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexPipelined<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementAccumulator,
    LayoutC,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    Stages,
    TransformA,
    TransformB,
    Operator
  >::ThreadblockMma;

  /// Planar complex epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex<
    ThreadblockShape,
    typename Mma::Policy::Operator,
    OperatorClass,
    ArchTag,
    ThreadblockShape::kK / WarpShape::kK,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  /// Define the kernel in terms of the default kernel
  using GemmKernel = kernel::GemmPlanarComplex<
    Mma,
    Epilogue,
    ThreadblockSwizzle
  >;

  // Array variant
  using GemmArrayKernel = kernel::GemmPlanarComplexArray<
    Mma,
    Epilogue,
    ThreadblockSwizzle
  >;
};
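
// Illustrative note: a kernel defined here with row-major LayoutC may be wrapped by
// device::GemmUniversalAdapter to obtain a host-side operator, e.g.
//
//   using Kernel = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
//     /* ...template arguments as declared above... */>::GemmKernel;
//
//   cutlass::gemm::device::GemmUniversalAdapter<Kernel> op;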

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
308
include/cutlass/gemm/kernel/default_gemm_universal.h
Normal file
@ -0,0 +1,308 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
      the appropriate threadblock-scoped epilogue.

      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
      specializations here choose 'device::GemmTransposed' to implement this functionality.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator,
  /// SFINAE condition selecting between the real- and complex-valued specializations
  typename Enable = void
>
struct DefaultGemmUniversal;

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//

template <
  /// Element type for A matrix operand
  typename ElementA,
  /// Layout type for A matrix operand
  typename LayoutA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB,
  /// Layout type for B matrix operand
  typename LayoutB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC,
  /// Layout type for C and D matrix operands
  typename LayoutC,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator>
struct DefaultGemmUniversal<
  ElementA,
  LayoutA,
  ComplexTransform::kNone,   // transform A
  kAlignmentA,
  ElementB,
  LayoutB,
  ComplexTransform::kNone,   // transform B
  kAlignmentB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  OperatorClass,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  Operator,
  typename std::enable_if< ! cutlass::is_complex<ElementAccumulator>::value>::type
> {

  using DefaultGemmKernel = typename kernel::DefaultGemm<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    true,
    Operator,
    false
  >::GemmKernel;

  /// Define the kernel in terms of the default kernel
  using GemmKernel = kernel::GemmUniversal<
    typename DefaultGemmKernel::Mma,
    typename DefaultGemmKernel::Epilogue,
    ThreadblockSwizzle
  >;
};
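
// Dispatch note: the real- and complex-valued specializations are selected by SFINAE on
// cutlass::is_complex<ElementAccumulator>, so any instantiation of DefaultGemmUniversal
// resolves to exactly one of the two definitions.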

/////////////////////////////////////////////////////////////////////////////////////////////////

//
// Complex-valued GEMM kernels
//

template <
  /// Element type for A matrix operand
  typename ElementA,
  /// Layout type for A matrix operand
  typename LayoutA,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB,
  /// Layout type for B matrix operand
  typename LayoutB,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC,
  /// Layout type for C and D matrix operands
  typename LayoutC,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator
>
struct DefaultGemmUniversal<
  ElementA,
  LayoutA,
  TransformA,
  kAlignmentA,
  ElementB,
  LayoutB,
  TransformB,
  kAlignmentB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  OperatorClass,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  Operator,
  typename std::enable_if<cutlass::is_complex<ElementAccumulator>::value>::type
> {

  using DefaultGemmKernel = typename kernel::DefaultGemmComplex<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    TransformA,
    TransformB,
    Operator,
    false     // SplitKSerial
  >::GemmKernel;

  /// Define the kernel in terms of the default kernel
  using GemmKernel = kernel::GemmUniversal<
    typename DefaultGemmKernel::Mma,
    typename DefaultGemmKernel::Epilogue,
    ThreadblockSwizzle
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
0
include/cutlass/gemm/kernel/default_gemv.h
Normal file → Executable file
@ -83,7 +83,7 @@ struct Gemm {
    //

    CUTLASS_HOST_DEVICE
    Params() { }
    Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }

    CUTLASS_HOST_DEVICE
    Params(
@ -94,7 +94,7 @@ struct Gemm {
      typename Epilogue::OutputTileIterator::TensorRef ref_C,
      typename Epilogue::OutputTileIterator::TensorRef ref_D,
      typename OutputOp::Params output_op = typename OutputOp::Params(),
      int *semaphore = nullptr
      int *workspace = nullptr
    ):
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
@ -106,13 +106,14 @@ struct Gemm {
      ref_C(ref_C),
      params_D(ref_D.layout()),
      ref_D(ref_D),
      output_op(output_op),
      semaphore(semaphore) {
      output_op(output_op) {

      int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
      int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();

      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;

      semaphore = workspace;
    }
  };
@ -220,7 +221,9 @@ struct Gemm {
|
||||
thread_idx,
|
||||
tb_offset_B);
|
||||
|
||||
int warp_idx = threadIdx.x / 32;
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
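This change, repeated in each kernel below, replaces a per-thread division with a value broadcast from lane 0, which lets the compiler prove `warp_idx` is warp-uniform and keep dependent address arithmetic on a uniform path. Note that the mask `0x1f` names only lanes 0 through 4; a full-warp mask would be `0xffffffff`, so treat the constant as specific to this commit. A minimal CUDA sketch of the broadcast idiom, written with the full-warp mask:

```cuda
#include <cstdio>

__global__ void warp_uniform_demo() {
  // Every lane computes the same expression, but the compiler cannot
  // prove that; broadcasting lane 0's result makes uniformity explicit.
  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
  int lane_idx = threadIdx.x % 32;

  if (lane_idx == 0) {
    printf("warp %d is uniform\n", warp_idx);
  }
}

int main() {
  warp_uniform_demo<<<1, 64>>>();  // two warps
  cudaDeviceSynchronize();
  return 0;
}
```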
253
include/cutlass/gemm/kernel/gemm_array.h
Normal file
@ -0,0 +1,253 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,           ///! Epilogue
  typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
struct GemmArray {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorA::Element const * const * ptr_A;
    typename Mma::IteratorB::Params params_B;
    typename Mma::IteratorB::Element const * const * ptr_B;
    typename Epilogue::OutputTileIterator::Params params_C;
    typename Epilogue::OutputTileIterator::Element const * const * ptr_C;
    typename Epilogue::OutputTileIterator::Params params_D;
    typename Epilogue::OutputTileIterator::Element * const * ptr_D;
    int64_t stride_D;
    typename OutputOp::Params epilogue;
    int batch_count;
    int gemm_k_iterations;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size_,
      cutlass::gemm::GemmCoord const & grid_tiled_shape_,
      typename Mma::IteratorA::Element const * const * ptr_A_,
      typename Mma::IteratorA::Layout layout_A,
      typename Mma::IteratorB::Element const * const * ptr_B_,
      typename Mma::IteratorB::Layout layout_B,
      typename Epilogue::OutputTileIterator::Element const * const * ptr_C_,
      typename Epilogue::OutputTileIterator::Layout layout_C,
      typename Epilogue::OutputTileIterator::Element * const * ptr_D_,
      typename Epilogue::OutputTileIterator::Layout layout_D,
      typename OutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      grid_tiled_shape(grid_tiled_shape_),
      params_A(layout_A),
      ptr_A(ptr_A_),
      params_B(layout_B),
      ptr_B(ptr_B_),
      params_C(layout_C),
      ptr_C(ptr_C_),
      params_D(layout_D),
      ptr_D(ptr_D_),
      epilogue(epilogue_),
      batch_count(batch_count_),
      gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) {

    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  GemmArray() { }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension
    for (int batch_idx = threadblock_swizzle.get_batch_idx();
      batch_idx < params.batch_count;
      batch_idx += gridDim.z) {

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        0
      };

      cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_tile_offset.n() * Mma::Shape::kN
      };

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.params_A,
        const_cast<typename Mma::IteratorA::Element *>(params.ptr_A[batch_idx]),
        params.problem_size.mk(),
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        params.params_B,
        const_cast<typename Mma::IteratorB::Element *>(params.ptr_B[batch_idx]),
        params.problem_size.kn(),
        thread_idx,
        tb_offset_B);

      //
      // Main loop
      //

      // Broadcast the warp_id computed by lane 0 to ensure dependent code
      // is compiled as warp-uniform.
      int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);

      int lane_idx = threadIdx.x % 32;

      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Compute threadblock-scoped matrix multiply-add
      mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

      //
      // Epilogue
      //

      OutputOp output_op(params.epilogue);

      //
      // Masked tile iterators constructed from members
      //

      threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

      // assume identity swizzle
      MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.n() * Mma::Shape::kN
      );

      // Tile iterator loading from source tensor
      typename Epilogue::OutputTileIterator iterator_C(
        params.params_C,
        const_cast<typename Epilogue::OutputTileIterator::Element *>(params.ptr_C[batch_idx]),
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      // Tile iterator writing to output tile
      typename Epilogue::OutputTileIterator iterator_D(
        params.params_D,
        params.ptr_D[batch_idx],
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      Epilogue epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // run efficient epilogue
      epilogue(output_op, iterator_D, accumulators, iterator_C);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
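`GemmArray` consumes arrays of per-batch pointers rather than a base pointer plus a uniform stride, so each batch may live at an arbitrary address. A hedged host-side sketch of preparing such pointer arrays (the allocation scheme and names like `batch_count` are illustrative, not part of this file):

```cpp
#include <cuda_runtime.h>
#include <vector>

int main() {
  int const batch_count = 8;
  int const m = 128, k = 64;

  // Allocate each batch independently; GemmArray assumes no uniform stride.
  std::vector<float *> host_ptr_A(batch_count);
  for (int b = 0; b < batch_count; ++b) {
    cudaMalloc(&host_ptr_A[b], sizeof(float) * m * k);
  }

  // The kernel dereferences ptr_A[batch_idx] on the device, so the pointer
  // array itself must also reside in device memory.
  float **device_ptr_A = nullptr;
  cudaMalloc(&device_ptr_A, sizeof(float *) * batch_count);
  cudaMemcpy(device_ptr_A, host_ptr_A.data(),
             sizeof(float *) * batch_count, cudaMemcpyHostToDevice);

  // device_ptr_A now matches the `Element const * const * ptr_A` parameter
  // shape. (Repeat for B, C, and D; cleanup omitted for brevity.)
  return 0;
}
```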
@ -193,8 +193,10 @@ struct GemmBatched {
// Main loop
//

// Construct thread-scoped matrix multiply
int warp_idx = threadIdx.x / 32;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);

int lane_idx = threadIdx.x % 32;

Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
693
include/cutlass/gemm/kernel/gemm_planar_complex.h
Normal file
@ -0,0 +1,693 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,           ///! Epilogue
  typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
struct GemmPlanarComplex {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;
  using Operator = typename Mma::Operator;
  using ArchTag = typename Mma::ArchTag;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(
    128 / sizeof_bits<ElementA>::value,
    128 / sizeof_bits<ElementB>::value);

  //
  // Additional types needed for reflection
  //

  using ElementAccumulator = typename Mma::Policy::Operator::ElementC;
  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::Shape;

  static int const kStages = Mma::kStages;

  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  //
  // Arguments structure
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size;
    int batch_count;

    typename EpilogueOutputOp::Params epilogue;

    void const * ptr_A_real;
    void const * ptr_A_imag;

    void const * ptr_B_real;
    void const * ptr_B_imag;

    void const * ptr_C_real;
    void const * ptr_C_imag;

    void * ptr_D_real;
    void * ptr_D_imag;

    int lda_real;
    int lda_imag;
    int ldb_real;
    int ldb_imag;
    int ldc_real;
    int ldc_imag;
    int ldd_real;
    int ldd_imag;

    int64_t batch_stride_A;
    int64_t batch_stride_A_imag;
    int64_t batch_stride_B;
    int64_t batch_stride_B_imag;
    int64_t batch_stride_C;
    int64_t batch_stride_C_imag;
    int64_t batch_stride_D;
    int64_t batch_stride_D_imag;

    //
    // Methods
    //

    Arguments():
      mode(GemmUniversalMode::kGemm),
      batch_count(1),
      ptr_A_real(nullptr),
      ptr_A_imag(nullptr),
      ptr_B_real(nullptr),
      ptr_B_imag(nullptr),
      ptr_C_real(nullptr),
      ptr_C_imag(nullptr),
      ptr_D_real(nullptr),
      ptr_D_imag(nullptr)
    { }

    /// constructs an arguments structure
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueOutputOp::Params epilogue,
      void const * ptr_A_real,
      void const * ptr_A_imag,
      void const * ptr_B_real,
      void const * ptr_B_imag,
      void const * ptr_C_real,
      void const * ptr_C_imag,
      void * ptr_D_real,
      void * ptr_D_imag,
      int lda_real,
      int lda_imag,
      int ldb_real,
      int ldb_imag,
      int ldc_real,
      int ldc_imag,
      int ldd_real,
      int ldd_imag,
      int64_t batch_stride_A = 0,
      int64_t batch_stride_A_imag = 0,
      int64_t batch_stride_B = 0,
      int64_t batch_stride_B_imag = 0,
      int64_t batch_stride_C = 0,
      int64_t batch_stride_C_imag = 0,
      int64_t batch_stride_D = 0,
      int64_t batch_stride_D_imag = 0
    ):
      mode(mode),
      problem_size(problem_size),
      batch_count(batch_count),
      epilogue(epilogue),
      ptr_A_real(ptr_A_real),
      ptr_A_imag(ptr_A_imag),
      ptr_B_real(ptr_B_real),
      ptr_B_imag(ptr_B_imag),
      ptr_C_real(ptr_C_real),
      ptr_C_imag(ptr_C_imag),
      ptr_D_real(ptr_D_real),
      ptr_D_imag(ptr_D_imag),
      lda_real(lda_real),
      lda_imag(lda_imag),
      ldb_real(ldb_real),
      ldb_imag(ldb_imag),
      ldc_real(ldc_real),
      ldc_imag(ldc_imag),
      ldd_real(ldd_real),
      ldd_imag(ldd_imag),
      batch_stride_A(batch_stride_A),
      batch_stride_A_imag(batch_stride_A_imag),
      batch_stride_B(batch_stride_B),
      batch_stride_B_imag(batch_stride_B_imag),
      batch_stride_C(batch_stride_C),
      batch_stride_C_imag(batch_stride_C_imag),
      batch_stride_D(batch_stride_D),
      batch_stride_D_imag(batch_stride_D_imag) {

    }

    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_A_real, args.ptr_B_real);
      std::swap(args.ptr_A_imag, args.ptr_B_imag);
      std::swap(args.lda_real, args.ldb_real);
      std::swap(args.lda_imag, args.ldb_imag);
      std::swap(args.batch_stride_A, args.batch_stride_B);
      std::swap(args.batch_stride_A_imag, args.batch_stride_B_imag);

      return args;
    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;

    typename Mma::IteratorA::Params params_A_real;
    typename Mma::IteratorA::Params params_A_imag;
    typename Mma::IteratorB::Params params_B_real;
    typename Mma::IteratorB::Params params_B_imag;
    typename Epilogue::OutputTileIterator::Params params_C_real;
    typename Epilogue::OutputTileIterator::Params params_C_imag;
    typename Epilogue::OutputTileIterator::Params params_D_real;
    typename Epilogue::OutputTileIterator::Params params_D_imag;

    typename EpilogueOutputOp::Params output_op;

    GemmUniversalMode mode;
    int batch_count;
    int gemm_k_size;

    void * ptr_A_real;
    void * ptr_A_imag;
    void * ptr_B_real;
    void * ptr_B_imag;
    void * ptr_C_real;
    void * ptr_C_imag;
    void * ptr_D_real;
    void * ptr_D_imag;

    int64_t batch_stride_A;
    int64_t batch_stride_A_imag;
    int64_t batch_stride_B;
    int64_t batch_stride_B_imag;
    int64_t batch_stride_C;
    int64_t batch_stride_C_imag;
    int64_t batch_stride_D;
    int64_t batch_stride_D_imag;

    int *semaphore;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      batch_count(0),
      gemm_k_size(0),
      mode(cutlass::gemm::GemmUniversalMode::kGemm),
      ptr_A_real(nullptr),
      ptr_A_imag(nullptr),
      ptr_B_real(nullptr),
      ptr_B_imag(nullptr),
      ptr_C_real(nullptr),
      ptr_C_imag(nullptr),
      ptr_D_real(nullptr),
      ptr_D_imag(nullptr),
      batch_stride_A(0),
      batch_stride_A_imag(0),
      batch_stride_B(0),
      batch_stride_B_imag(0),
      batch_stride_C(0),
      batch_stride_C_imag(0),
      batch_stride_D(0),
      batch_stride_D_imag(0),
      semaphore(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      int gemm_k_size,
      void *workspace = nullptr
    ):
      problem_size(args.problem_size),
      grid_tiled_shape(grid_tiled_shape),
      params_A_real(args.lda_real),
      params_A_imag(args.lda_imag),
      params_B_real(args.ldb_real),
      params_B_imag(args.ldb_imag),
      params_C_real(args.ldc_real),
      params_C_imag(args.ldc_imag),
      params_D_real(args.ldd_real),
      params_D_imag(args.ldd_imag),
      output_op(args.epilogue),
      mode(args.mode),
      batch_count(args.batch_count),
      gemm_k_size(gemm_k_size),
      ptr_A_real(const_cast<void *>(args.ptr_A_real)),
      ptr_A_imag(const_cast<void *>(args.ptr_A_imag)),
      ptr_B_real(const_cast<void *>(args.ptr_B_real)),
      ptr_B_imag(const_cast<void *>(args.ptr_B_imag)),
      ptr_C_real(const_cast<void *>(args.ptr_C_real)),
      ptr_C_imag(const_cast<void *>(args.ptr_C_imag)),
      ptr_D_real(args.ptr_D_real),
      ptr_D_imag(args.ptr_D_imag),
      batch_stride_A(args.batch_stride_A),
      batch_stride_A_imag(args.batch_stride_A_imag),
      batch_stride_B(args.batch_stride_B),
      batch_stride_B_imag(args.batch_stride_B_imag),
      batch_stride_C(args.batch_stride_C),
      batch_stride_C_imag(args.batch_stride_C_imag),
      batch_stride_D(args.batch_stride_D),
      batch_stride_D_imag(args.batch_stride_D_imag),
      semaphore(static_cast<int *>(workspace)) {

    }

    void update(
      Arguments const &args,
      void *workspace = nullptr) {

      ptr_A_real = const_cast<void *>(args.ptr_A_real);
      ptr_A_imag = const_cast<void *>(args.ptr_A_imag);

      ptr_B_real = const_cast<void *>(args.ptr_B_real);
      ptr_B_imag = const_cast<void *>(args.ptr_B_imag);

      ptr_C_real = const_cast<void *>(args.ptr_C_real);
      ptr_C_imag = const_cast<void *>(args.ptr_C_imag);

      ptr_D_real = const_cast<void *>(args.ptr_D_real);
      ptr_D_imag = const_cast<void *>(args.ptr_D_imag);

      batch_stride_A = args.batch_stride_A;
      batch_stride_A_imag = args.batch_stride_A_imag;
      batch_stride_B = args.batch_stride_B;
      batch_stride_B_imag = args.batch_stride_B_imag;
      batch_stride_C = args.batch_stride_C;
      batch_stride_C_imag = args.batch_stride_C_imag;
      batch_stride_D = args.batch_stride_D;
      batch_stride_D_imag = args.batch_stride_D_imag;

      output_op = args.epilogue;

      semaphore = static_cast<int *>(workspace);
    }
  };
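This Arguments/Params split is the host/device handoff pattern shared by all the universal kernels in this commit: `Arguments` mirrors a BLAS-style call, `Params` stores layout-derived iterator state precomputed once on the host plus the workspace-backed semaphore pointer, and `update()` re-points an existing `Params` at new tensors without recomputing that state. A stripped-down illustration of the idiom (all types and fields below are placeholders, not CUTLASS's):

```cpp
#include <cstdint>

// Host-facing description of one call.
struct Arguments {
  int m, n, k;
  void const *ptr_A;
  int lda;
};

// Device-facing, trivially copyable state derived from Arguments.
struct Params {
  int m, n, k;
  void *ptr_A;
  int64_t stride_bytes;  // precomputed from lda once, on the host
  int *semaphore;        // backed by a caller-provided workspace

  Params(): m(0), n(0), k(0), ptr_A(nullptr), stride_bytes(0), semaphore(nullptr) { }

  Params(Arguments const &args, void *workspace):
    m(args.m), n(args.n), k(args.k),
    ptr_A(const_cast<void *>(args.ptr_A)),
    stride_bytes(int64_t(args.lda) * sizeof(float)),
    semaphore(static_cast<int *>(workspace)) { }

  // Swap tensors without recomputing layout-derived state.
  void update(Arguments const &args, void *workspace) {
    ptr_A = const_cast<void *>(args.ptr_A);
    semaphore = static_cast<int *>(workspace);
  }
};
```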
  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmPlanarComplex() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(Arguments const &args) {

    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
      (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
      (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A_real = static_cast<ElementA *>(params.ptr_A_real);
    ElementA *ptr_A_imag = static_cast<ElementA *>(params.ptr_A_imag);

    ElementB *ptr_B_real = static_cast<ElementB *>(params.ptr_B_real);
    ElementB *ptr_B_imag = static_cast<ElementB *>(params.ptr_B_imag);

    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm ||
      params.mode == GemmUniversalMode::kGemmSplitKParallel) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {

        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A;
      ptr_A_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A_imag;
      ptr_B_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B;
      ptr_B_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B_imag;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_A_real = static_cast<ElementA * const *>(params.ptr_A_real)[threadblock_tile_offset.k()];
      ptr_A_imag = static_cast<ElementA * const *>(params.ptr_A_imag)[threadblock_tile_offset.k()];
      ptr_B_real = static_cast<ElementB * const *>(params.ptr_B_real)[threadblock_tile_offset.k()];
      ptr_B_imag = static_cast<ElementB * const *>(params.ptr_B_imag)[threadblock_tile_offset.k()];
    }

    __syncthreads();

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A_real(
      params.params_A_real,
      ptr_A_real,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorA iterator_A_imag(
      params.params_A_imag,
      ptr_A_imag,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B_real(
      params.params_B_real,
      ptr_B_real,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    typename Mma::IteratorB iterator_B_imag(
      params.params_B_imag,
      ptr_B_imag,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A_real,
      iterator_A_imag,
      iterator_B_real,
      iterator_B_imag,
      accumulators);

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    ElementC *ptr_C_real = static_cast<ElementC *>(params.ptr_C_real);
    ElementC *ptr_C_imag = static_cast<ElementC *>(params.ptr_C_imag);
    ElementC *ptr_D_real = static_cast<ElementC *>(params.ptr_D_real);
    ElementC *ptr_D_imag = static_cast<ElementC *>(params.ptr_D_imag);

    //
    // Fetch pointers based on mode.
    //

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    if (params.mode == GemmUniversalMode::kGemm) {

      // If performing a reduction via split-K, fetch the initial synchronization
      if (params.grid_tiled_shape.k() > 1) {

        // Fetch the synchronization lock initially but do not block.
        semaphore.fetch();

        // Indicate which position in a serial reduction the output operator is currently updating
        output_op.set_k_partition(threadblock_tile_offset.k());
      }
    }
    else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
      ptr_D_real += threadblock_tile_offset.k() * params.batch_stride_D;
      ptr_D_imag += threadblock_tile_offset.k() * params.batch_stride_D_imag;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_C_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C;
      ptr_C_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C_imag;
      ptr_D_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D;
      ptr_D_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D_imag;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_C_real = static_cast<ElementC * const *>(params.ptr_C_real)[threadblock_tile_offset.k()];
      ptr_C_imag = static_cast<ElementC * const *>(params.ptr_C_imag)[threadblock_tile_offset.k()];
      ptr_D_real = static_cast<ElementC * const *>(params.ptr_D_real)[threadblock_tile_offset.k()];
      ptr_D_imag = static_cast<ElementC * const *>(params.ptr_D_imag)[threadblock_tile_offset.k()];
    }

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C_real(
      params.params_C_real,
      ptr_C_real,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    typename Epilogue::OutputTileIterator iterator_C_imag(
      params.params_C_imag,
      ptr_C_imag,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D_real(
      params.params_D_real,
      ptr_D_real,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    typename Epilogue::OutputTileIterator iterator_D_imag(
      params.params_D_imag,
      ptr_D_imag,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    //
    // Construct epilogue
    //

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C_real = iterator_D_real;
        iterator_C_imag = iterator_D_imag;
      }

      semaphore.wait(threadblock_tile_offset.k());

      __threadfence();
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(
      output_op,
      iterator_D_real,
      iterator_D_imag,
      accumulators,
      iterator_C_real,
      iterator_C_imag);

    //
    // Release the semaphore
    //

    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_offset.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};
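The wait/release pairing above serializes the split-K partitions of one output tile: partition k spins until the tile's counter reads k, accumulates into D, then releases k + 1, and the final partition resets the counter to 0 for subsequent grids. A hedged sketch of the same counter protocol using a plain atomic, with each block standing in for one partition (CUTLASS's `Semaphore` additionally broadcasts the fetched value through shared memory):

```cuda
#include <cstdio>

// Each block is one "split-K partition" of the same output tile; the
// zero-initialized counter serializes their updates in partition order.
__global__ void serialized_reduce(int *lock, float const *partial, float *out) {
  int k = blockIdx.x;  // partition index

  if (threadIdx.x == 0) {
    while (atomicAdd(lock, 0) != k) { }  // wait for our turn
    *out += partial[k];                  // accumulate into the tile
    __threadfence();                     // publish before releasing
    int next = (k + 1 == gridDim.x) ? 0 : k + 1;  // last partition resets
    atomicExch(lock, next);
  }
}

int main() {
  int *lock; float *partial, *out;
  cudaMalloc(&lock, sizeof(int));
  cudaMalloc(&partial, 4 * sizeof(float));
  cudaMalloc(&out, sizeof(float));
  cudaMemset(lock, 0, sizeof(int));
  cudaMemset(out, 0, sizeof(float));
  float h_partial[4] = {1.f, 2.f, 3.f, 4.f};
  cudaMemcpy(partial, h_partial, sizeof(h_partial), cudaMemcpyHostToDevice);

  serialized_reduce<<<4, 32>>>(lock, partial, out);
  cudaDeviceSynchronize();

  float h_out;
  cudaMemcpy(&h_out, out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %g\n", h_out);  // 10
  cudaFree(lock); cudaFree(partial); cudaFree(out);
}
```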

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
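The kernel above never touches interleaved complex values: A, B, C, and D are each stored as two disjoint real-valued planes, and the four-operand `mma(...)` call accumulates the cross terms of (a + bi)(c + di) = (ac - bd) + (ad + bc)i directly into a planar accumulator. A scalar sketch of the arithmetic the threadblock performs per output element (illustrative only, not the CUTLASS implementation):

```cpp
#include <cstdio>

// Planar complex GEMM on scalars: real and imaginary parts live in
// disjoint arrays, and each complex product is four real multiply-adds.
void planar_complex_gemm(
    int M, int N, int K,
    float const *A_real, float const *A_imag,  // M x K, row-major
    float const *B_real, float const *B_imag,  // K x N, row-major
    float *D_real, float *D_imag) {            // M x N, row-major
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc_real = 0.f, acc_imag = 0.f;
      for (int p = 0; p < K; ++p) {
        float a_r = A_real[i * K + p], a_i = A_imag[i * K + p];
        float b_r = B_real[p * N + j], b_i = B_imag[p * N + j];
        acc_real += a_r * b_r - a_i * b_i;  // real part of a*b
        acc_imag += a_r * b_i + a_i * b_r;  // imaginary part of a*b
      }
      D_real[i * N + j] = acc_real;
      D_imag[i * N + j] = acc_imag;
    }
  }
}

int main() {
  float ar[1] = {1.f}, ai[1] = {2.f}, br[1] = {3.f}, bi[1] = {4.f};
  float dr[1], di[1];
  planar_complex_gemm(1, 1, 1, ar, ai, br, bi, dr, di);
  std::printf("(1+2i)(3+4i) = %g + %gi\n", dr[0], di[0]);  // -5 + 10i
}
```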
583
include/cutlass/gemm/kernel/gemm_planar_complex_array.h
Normal file
@ -0,0 +1,583 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,           ///! Epilogue
  typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
struct GemmPlanarComplexArray {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;
  using Operator = typename Mma::Operator;
  using ArchTag = typename Mma::ArchTag;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(
    128 / sizeof_bits<ElementA>::value,
    128 / sizeof_bits<ElementB>::value);

  //
  // Additional types needed for reflection
  //

  using ElementAccumulator = typename Mma::Policy::Operator::ElementC;
  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::Shape;

  static int const kStages = Mma::kStages;

  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  //
  // Arguments structure
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size;
    int batch_count;

    typename EpilogueOutputOp::Params epilogue;

    int const *ptr_M;
    int const *ptr_N;
    int const *ptr_K;

    void const * const * ptr_A_real;
    void const * const * ptr_A_imag;

    void const * const * ptr_B_real;
    void const * const * ptr_B_imag;

    void const * const * ptr_C_real;
    void const * const * ptr_C_imag;

    void * const * ptr_D_real;
    void * const * ptr_D_imag;

    int lda_real;
    int lda_imag;
    int ldb_real;
    int ldb_imag;
    int ldc_real;
    int ldc_imag;
    int ldd_real;
    int ldd_imag;

    int64_t batch_stride_D; // unused

    //
    // Methods
    //

    Arguments():
      mode(GemmUniversalMode::kArray),
      batch_count(1),
      ptr_M(nullptr),
      ptr_N(nullptr),
      ptr_K(nullptr),
      ptr_A_real(nullptr),
      ptr_A_imag(nullptr),
      ptr_B_real(nullptr),
      ptr_B_imag(nullptr),
      ptr_C_real(nullptr),
      ptr_C_imag(nullptr),
      ptr_D_real(nullptr),
      ptr_D_imag(nullptr),
      batch_stride_D(0)
    { }

    /// constructs an arguments structure
    Arguments(
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueOutputOp::Params epilogue,
      int const *ptr_M,
      int const *ptr_N,
      int const *ptr_K,
      void const * const * ptr_A_real,
      void const * const * ptr_A_imag,
      void const * const * ptr_B_real,
      void const * const * ptr_B_imag,
      void const * const * ptr_C_real,
      void const * const * ptr_C_imag,
      void * const * ptr_D_real,
      void * const * ptr_D_imag,
      int lda_real,
      int lda_imag,
      int ldb_real,
      int ldb_imag,
      int ldc_real,
      int ldc_imag,
      int ldd_real,
      int ldd_imag
    ):
      mode(GemmUniversalMode::kArray),
      problem_size(problem_size),
      batch_count(batch_count),
      epilogue(epilogue),
      ptr_M(ptr_M),
      ptr_N(ptr_N),
      ptr_K(ptr_K),
      ptr_A_real(ptr_A_real),
      ptr_A_imag(ptr_A_imag),
      ptr_B_real(ptr_B_real),
      ptr_B_imag(ptr_B_imag),
      ptr_C_real(ptr_C_real),
      ptr_C_imag(ptr_C_imag),
      ptr_D_real(ptr_D_real),
      ptr_D_imag(ptr_D_imag),
      lda_real(lda_real),
      lda_imag(lda_imag),
      ldb_real(ldb_real),
      ldb_imag(ldb_imag),
      ldc_real(ldc_real),
      ldc_imag(ldc_imag),
      ldd_real(ldd_real),
      ldd_imag(ldd_imag),
      batch_stride_D(0) {

    }

    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_M, args.ptr_N);
      std::swap(args.ptr_A_real, args.ptr_B_real);
      std::swap(args.ptr_A_imag, args.ptr_B_imag);
      std::swap(args.lda_real, args.ldb_real);
      std::swap(args.lda_imag, args.ldb_imag);

      return args;
    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;

    typename Mma::IteratorA::Params params_A_real;
    typename Mma::IteratorA::Params params_A_imag;
    typename Mma::IteratorB::Params params_B_real;
    typename Mma::IteratorB::Params params_B_imag;
    typename Epilogue::OutputTileIterator::Params params_C_real;
    typename Epilogue::OutputTileIterator::Params params_C_imag;
    typename Epilogue::OutputTileIterator::Params params_D_real;
    typename Epilogue::OutputTileIterator::Params params_D_imag;

    typename EpilogueOutputOp::Params output_op;

    int batch_count;

    int const *ptr_M;
    int const *ptr_N;
    int const *ptr_K;

    void const * const * ptr_A_real;
    void const * const * ptr_A_imag;
    void const * const * ptr_B_real;
    void const * const * ptr_B_imag;
    void const * const * ptr_C_real;
    void const * const * ptr_C_imag;
    void * const * ptr_D_real;
    void * const * ptr_D_imag;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      batch_count(0),
      ptr_M(nullptr),
      ptr_N(nullptr),
      ptr_K(nullptr),
      ptr_A_real(nullptr),
      ptr_A_imag(nullptr),
      ptr_B_real(nullptr),
      ptr_B_imag(nullptr),
      ptr_C_real(nullptr),
      ptr_C_imag(nullptr),
      ptr_D_real(nullptr),
      ptr_D_imag(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      int gemm_k_size = 0,      // ignored
      void *workspace = nullptr // ignored
    ):
      problem_size(args.problem_size),
      grid_tiled_shape(grid_tiled_shape),
      ptr_M(args.ptr_M),
      ptr_N(args.ptr_N),
      ptr_K(args.ptr_K),
      params_A_real(args.lda_real),
      params_A_imag(args.lda_imag),
      params_B_real(args.ldb_real),
      params_B_imag(args.ldb_imag),
      params_C_real(args.ldc_real),
      params_C_imag(args.ldc_imag),
      params_D_real(args.ldd_real),
      params_D_imag(args.ldd_imag),
      output_op(args.epilogue),
      batch_count(args.batch_count),
      ptr_A_real(args.ptr_A_real),
      ptr_A_imag(args.ptr_A_imag),
      ptr_B_real(args.ptr_B_real),
      ptr_B_imag(args.ptr_B_imag),
      ptr_C_real(args.ptr_C_real),
      ptr_C_imag(args.ptr_C_imag),
      ptr_D_real(args.ptr_D_real),
      ptr_D_imag(args.ptr_D_imag) {

    }

    void update(
      Arguments const &args,
      void *workspace = nullptr) {

      ptr_M = args.ptr_M;
      ptr_N = args.ptr_N;
      ptr_K = args.ptr_K;

      ptr_A_real = args.ptr_A_real;
      ptr_A_imag = args.ptr_A_imag;

      ptr_B_real = args.ptr_B_real;
      ptr_B_imag = args.ptr_B_imag;

      ptr_C_real = args.ptr_C_real;
      ptr_C_imag = args.ptr_C_imag;

      ptr_D_real = args.ptr_D_real;
      ptr_D_imag = args.ptr_D_imag;

      output_op = args.epilogue;
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmPlanarComplexArray() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(Arguments const &args) {

    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
      (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
      (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
    int batch_idx = threadblock_tile_offset.k();

    int problem_size_m = params.problem_size.m();
    int problem_size_n = params.problem_size.n();
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A_real = static_cast<ElementA *>(const_cast<void *>(params.ptr_A_real[batch_idx]));
    ElementA *ptr_A_imag = static_cast<ElementA *>(const_cast<void *>(params.ptr_A_imag[batch_idx]));

    ElementB *ptr_B_real = static_cast<ElementB *>(const_cast<void *>(params.ptr_B_real[batch_idx]));
    ElementB *ptr_B_imag = static_cast<ElementB *>(const_cast<void *>(params.ptr_B_imag[batch_idx]));

    //
    // If pointers for problem sizes are specified, these are loaded from global memory
    //

    if (params.ptr_M) {
      problem_size_m = params.ptr_M[batch_idx];
    }

    if (params.ptr_N) {
      problem_size_n = params.ptr_N[batch_idx];
    }

    if (params.ptr_K) {
      problem_size_k = params.ptr_K[batch_idx];
    }

    int const kBlockCountM = (problem_size_m + Mma::Shape::kM - 1) / Mma::Shape::kM;
    int const kBlockCountN = (problem_size_n + Mma::Shape::kN - 1) / Mma::Shape::kN;

    int const kGemmKIterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    //
    // Each threadblock loops over the logical problem size which the kernel may have discovered
    // after the grid is launched.
    //

    CUTLASS_PRAGMA_NO_UNROLL
    for (int block_m = threadblock_tile_offset.m();
      block_m < kBlockCountM;
      block_m += params.grid_tiled_shape.m()) {

      CUTLASS_PRAGMA_NO_UNROLL
      for (int block_n = threadblock_tile_offset.n();
        block_n < kBlockCountN;
        block_n += params.grid_tiled_shape.n()) {

        //
        // Compute indices within threadblock and warp.
        //
        int thread_idx = threadIdx.x;

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Proceed with regular GEMM logic.
        //

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{ block_m * Mma::Shape::kM, 0};
        cutlass::MatrixCoord tb_offset_B{ 0, block_n * Mma::Shape::kN };

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A_real(
          params.params_A_real,
          ptr_A_real,
          {problem_size_m, problem_size_k},
          thread_idx,
          tb_offset_A);

        typename Mma::IteratorA iterator_A_imag(
          params.params_A_imag,
          ptr_A_imag,
          {problem_size_m, problem_size_k},
          thread_idx,
          tb_offset_A);

        typename Mma::IteratorB iterator_B_real(
          params.params_B_real,
          ptr_B_real,
          {problem_size_k, problem_size_n},
          thread_idx,
          tb_offset_B);

        typename Mma::IteratorB iterator_B_imag(
          params.params_B_imag,
          ptr_B_imag,
          {problem_size_k, problem_size_n},
          thread_idx,
          tb_offset_B);

        //
        // Main loop
        //

        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        // Compute threadblock-scoped matrix multiply-add
        mma(
          kGemmKIterations,
          accumulators,
          iterator_A_real,
          iterator_A_imag,
          iterator_B_real,
          iterator_B_imag,
          accumulators);

        //
        // Epilogue
        //

        EpilogueOutputOp output_op(params.output_op);

        //
        // Masked tile iterators constructed from members
        //

        // assume identity swizzle
        MatrixCoord threadblock_offset(
          block_m * Mma::Shape::kM,
          block_n * Mma::Shape::kN
        );

        ElementC *ptr_C_real = static_cast<ElementC *>(const_cast<void *>(params.ptr_C_real[batch_idx]));
        ElementC *ptr_C_imag = static_cast<ElementC *>(const_cast<void *>(params.ptr_C_imag[batch_idx]));
        ElementC *ptr_D_real = static_cast<ElementC *>(params.ptr_D_real[batch_idx]);
        ElementC *ptr_D_imag = static_cast<ElementC *>(params.ptr_D_imag[batch_idx]);

        // Tile iterator loading from source tensor.
        typename Epilogue::OutputTileIterator iterator_C_real(
          params.params_C_real,
          ptr_C_real,
          {problem_size_m, problem_size_n},
          thread_idx,
          threadblock_offset
        );

        typename Epilogue::OutputTileIterator iterator_C_imag(
          params.params_C_imag,
          ptr_C_imag,
          {problem_size_m, problem_size_n},
          thread_idx,
          threadblock_offset
        );

        // Tile iterator writing to destination tensor.
        typename Epilogue::OutputTileIterator iterator_D_real(
          params.params_D_real,
          ptr_D_real,
          {problem_size_m, problem_size_n},
          thread_idx,
          threadblock_offset
        );

        typename Epilogue::OutputTileIterator iterator_D_imag(
          params.params_D_imag,
          ptr_D_imag,
          {problem_size_m, problem_size_n},
          thread_idx,
          threadblock_offset
        );

        //
        // Construct epilogue
        //

        Epilogue epilogue(
          shared_storage.epilogue,
          thread_idx,
          warp_idx,
          lane_idx);

        // Execute the epilogue operator to update the destination tensor.
        epilogue(
          output_op,
          iterator_D_real,
          iterator_D_imag,
          accumulators,
          iterator_C_real,
          iterator_C_imag);

      } // for block_n
    } // for block_m
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
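Unlike the other kernels in this commit, `GemmPlanarComplexArray` may discover its per-batch problem sizes (`ptr_M`/`ptr_N`/`ptr_K`) only after the grid is launched, so each threadblock grid-strides over however many tiles the runtime sizes imply rather than owning exactly one tile. A minimal CUDA sketch of that launch-size-independent loop pattern (the tile size and kernel are illustrative):

```cuda
#include <cstdio>

__global__ void tile_visitor(int const *ptr_M, int const *ptr_N, int tile) {
  // Problem extent is read from device memory after launch.
  int blocks_m = (ptr_M[0] + tile - 1) / tile;
  int blocks_n = (ptr_N[0] + tile - 1) / tile;

  // Grid-stride in both dimensions: correct for any gridDim.
  for (int bm = blockIdx.x; bm < blocks_m; bm += gridDim.x) {
    for (int bn = blockIdx.y; bn < blocks_n; bn += gridDim.y) {
      if (threadIdx.x == 0) {
        printf("block (%d, %d) processes tile (%d, %d)\n",
               blockIdx.x, blockIdx.y, bm, bn);
      }
    }
  }
}

int main() {
  int host_extents[2] = {300, 200};
  int *dev_extents;
  cudaMalloc(&dev_extents, sizeof(host_extents));
  cudaMemcpy(dev_extents, host_extents, sizeof(host_extents), cudaMemcpyHostToDevice);

  dim3 grid(2, 2);  // deliberately smaller than the 3x2 tile grid for tile = 128
  tile_visitor<<<grid, 32>>>(dev_extents, dev_extents + 1, 128);
  cudaDeviceSynchronize();
  cudaFree(dev_extents);
}
```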
|
||||
|
||||
537
include/cutlass/gemm/kernel/gemm_universal.h
Normal file
537
include/cutlass/gemm/kernel/gemm_universal.h
Normal file
@ -0,0 +1,537 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Universal GEMM kernel supporting the classic GEMM, batched, pointer-array, and split-K modes.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_    ///! Threadblock swizzling function
>
struct GemmUniversal {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueOutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
  using Operator = typename Mma::Operator;

  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::Shape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
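  // For example, with half-precision A and B operands (16 bits each), this evaluates to
  // const_max(128 / 16, 128 / 16) = 8 elements, so every split-K slice begins on a 128b boundary.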

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size;
    int batch_count;

    typename EpilogueOutputOp::Params epilogue;

    void const * ptr_A;
    void const * ptr_B;
    void const * ptr_C;
    void * ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C;
    int64_t batch_stride_D;

    int lda;
    int ldb;
    int ldc;
    int ldd;

    //
    // Methods
    //

    Arguments():
      mode(GemmUniversalMode::kGemm),
      batch_count(1),
      ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { }

    /// Constructs an arguments structure
    Arguments(
      GemmUniversalMode mode,
      GemmCoord problem_size,
      int batch_count,
      typename EpilogueOutputOp::Params epilogue,
      void const * ptr_A,
      void const * ptr_B,
      void const * ptr_C,
      void * ptr_D,
      int64_t batch_stride_A,
      int64_t batch_stride_B,
      int64_t batch_stride_C,
      int64_t batch_stride_D,
      int lda,
      int ldb,
      int ldc,
      int ldd
    ):
      mode(mode),
      problem_size(problem_size),
      batch_count(batch_count),
      epilogue(epilogue),
      ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
      batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
      lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) {

    }

    /// Returns arguments for the transposed problem
    Arguments transposed_problem() const {
      Arguments args(*this);

      std::swap(args.problem_size.m(), args.problem_size.n());
      std::swap(args.ptr_A, args.ptr_B);
      std::swap(args.lda, args.ldb);
      std::swap(args.batch_stride_A, args.batch_stride_B);

      return args;
    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {

    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename Epilogue::OutputTileIterator::Params params_C;
    typename Epilogue::OutputTileIterator::Params params_D;

    typename EpilogueOutputOp::Params output_op;

    GemmUniversalMode mode;
    int batch_count;
    int gemm_k_size;

    void * ptr_A;
    void * ptr_B;
    void * ptr_C;
    void * ptr_D;

    int64_t batch_stride_A;
    int64_t batch_stride_B;
    int64_t batch_stride_C;
    int64_t batch_stride_D;

    int *semaphore;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      params_A(0),
      params_B(0),
      params_C(0),
      params_D(0),
      mode(cutlass::gemm::GemmUniversalMode::kGemm),
      batch_count(0),
      gemm_k_size(0),
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_C(nullptr),
      ptr_D(nullptr),
      batch_stride_A(0),
      batch_stride_B(0),
      batch_stride_C(0),
      batch_stride_D(0),
      semaphore(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      int gemm_k_size,
      void *workspace = nullptr
    ):
      problem_size(args.problem_size),
      grid_tiled_shape(grid_tiled_shape),
      params_A(args.lda),
      params_B(args.ldb),
      params_C(args.ldc),
      params_D(args.ldd),
      output_op(args.epilogue),
      mode(args.mode),
      batch_count(args.batch_count),
      gemm_k_size(gemm_k_size),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      ptr_C(const_cast<void *>(args.ptr_C)),
      ptr_D(args.ptr_D),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      batch_stride_C(args.batch_stride_C),
      batch_stride_D(args.batch_stride_D),
      semaphore(static_cast<int *>(workspace)) {

    }

    CUTLASS_HOST_DEVICE
    void update(
      Arguments const &args,
      void *workspace = nullptr) {

      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);
      ptr_C = const_cast<void *>(args.ptr_C);
      ptr_D = args.ptr_D;

      output_op = args.epilogue;

      semaphore = static_cast<int *>(workspace);
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmUniversal() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size) {

    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
      (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
      (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm ||
      params.mode == GemmUniversalMode::kGemmSplitKParallel) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {

        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
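    // Worked example: with problem k() = 4096, grid_tiled_shape.k() = 4, and gemm_k_size = 1024,
    // the slice at threadblock_tile_offset.k() = 2 computes offset_k = 2048 and, since it is not
    // the last slice, clamps problem_size_k to (2 + 1) * 1024 = 3072, covering K in [2048, 3072).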
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
      ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
    }

    __syncthreads();

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      accumulators);

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

    //
    // Fetch pointers based on mode.
    //

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    if (params.mode == GemmUniversalMode::kGemm) {

      // If performing a reduction via split-K, fetch the initial synchronization
      if (params.grid_tiled_shape.k() > 1) {

        // Fetch the synchronization lock initially but do not block.
        semaphore.fetch();

        // Indicate which position in a serial reduction the output operator is currently updating
        output_op.set_k_partition(threadblock_tile_offset.k());
      }
    }
    else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
      ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
    }
    else if (params.mode == GemmUniversalMode::kArray) {
      ptr_C = static_cast<ElementC * const *>(params.ptr_C)[threadblock_tile_offset.k()];
      ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
    }

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      ptr_C,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      ptr_D,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_offset.k());

      __threadfence();
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(
      output_op,
      iterator_D,
      accumulators,
      iterator_C);

    //
    // Release the semaphore
    //

    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_offset.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
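To make the mode-dependent split-K arithmetic above concrete, here is a minimal host-side sketch. It is not part of this commit: the `Mode` enum and `split_k_range` helper are illustrative stand-ins for `cutlass::gemm::GemmUniversalMode` and the slice computation performed inside `GemmUniversal::operator()`.

```
// Standalone sketch (assumed names): reproduces the K-range selection that
// GemmUniversal::operator() performs per threadblock for each mode.
enum class Mode { kGemm, kGemmSplitKParallel, kBatched, kArray };

struct KRange { int begin; int end; };  // half-open interval [begin, end)

KRange split_k_range(Mode mode, int problem_k, int grid_k, int gemm_k_size, int slice) {
  KRange range{0, problem_k};
  if (mode == Mode::kGemm || mode == Mode::kGemmSplitKParallel) {
    range.begin = slice * gemm_k_size;          // offset_k in the kernel
    if (slice + 1 < grid_k) {
      range.end = (slice + 1) * gemm_k_size;    // clamped problem_size_k
    }                                           // the last slice keeps the full K extent
  }
  return range;                                 // kBatched / kArray cover all of K per batch
}

int main() {
  KRange r = split_k_range(Mode::kGemm, 4096, 4, 1024, 3);
  return (r.begin == 3072 && r.end == 4096) ? 0 : 1;  // last slice covers [3072, 4096)
}
```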
0
include/cutlass/gemm/kernel/gemv_batched_strided.h
Normal file → Executable file
0
include/cutlass/gemm/threadblock/default_gemv_core.h
Normal file → Executable file
@ -33,6 +33,7 @@
#include "cutlass/arch/arch.h"
#include "cutlass/arch/wmma.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
@ -143,8 +144,9 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      layout::RowMajor, typename MmaCore::MmaPolicy>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass Simt)
/// Specialization for row-major output (OperatorClass TensorOp)
template <
    /// Element type for A matrix operand
    typename ElementA,
@ -199,8 +201,8 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
      layout::RowMajor, typename MmaCore::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////
/// Specialization for column-major-interleaved output
template <
    /// Element type for A matrix operand
@ -268,7 +270,9 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
};

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

/// Specialization for SIMT IDP4A Kernels
template <
    /// Layout type for A matrix operand
@ -326,6 +330,8 @@ struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
      layout::RowMajor, typename MmaCore::MmaPolicy>;
};

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Specialization for Wmma TensorOp operator with 2 staged pipeline
template <
@ -384,6 +390,8 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      LayoutC, typename MmaCore::MmaPolicy>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for Wmma TensorOp operator with 1 staged pipeline
template <
    ///< Element type for A matrix operand
@ -440,6 +448,7 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
      IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
      LayoutC, typename MmaCore::MmaPolicy>;
};

////////////////////////////////////////////////////////////////////////////////
#endif //CUTLASS_ARCH_WMMA_ENABLED

@ -598,6 +598,8 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
};

////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: column-major-interleave32
@ -0,0 +1,124 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Default configuration constructing a pipelined planar complex threadblock-scoped MMA
      from its real-valued counterpart.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"

#include "cutlass/gemm/warp/mma_planar_complex.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/mma_planar_complex_pipelined.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Complex transformation on operand A
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Complex transformation on operand B
  ComplexTransform TransformB = ComplexTransform::kNone,
  /// Math operator tag (e.g. arch::OpMultiplyAdd)
  typename Operator = arch::OpMultiplyAdd
>
struct DefaultMmaPlanarComplexPipelined {

  // Construct a planar complex variant from the real-valued variant
  using RealMma = typename DefaultMma<
    ElementA_,
    LayoutA_,
    kAlignmentA,
    ElementB_,
    LayoutB_,
    kAlignmentB,
    ElementAccumulator_,
    LayoutC_,
    OperatorClass_,
    ArchTag_,
    ThreadblockShape_,
    WarpShape_,
    InstructionShape_,
    Stages,
    Operator
  >::ThreadblockMma;

  using ThreadblockMma = MmaPlanarComplexPipelined<
    ThreadblockShape_,
    typename RealMma::IteratorA,
    typename RealMma::SmemIteratorA,
    typename RealMma::IteratorB,
    typename RealMma::SmemIteratorB,
    ElementAccumulator_,
    LayoutC_,
    typename RealMma::Policy,
    Stages,
    TransformA,
    TransformB
  >;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
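As a usage sketch, the trait above might be instantiated as follows. The element types, layouts, and tile shapes here are assumptions chosen for a Volta-class Tensor Core configuration, not values taken from this commit:

```
// Illustrative only: one plausible instantiation of DefaultMmaPlanarComplexPipelined.
using Mma = cutlass::gemm::threadblock::DefaultMmaPlanarComplexPipelined<
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // A: element, layout, alignment
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B: element, layout, alignment
    float, cutlass::layout::RowMajor,                   // accumulator element and C layout
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm70,
    cutlass::gemm::GemmShape<64, 64, 32>,               // threadblock tile
    cutlass::gemm::GemmShape<32, 32, 32>,               // warp tile
    cutlass::gemm::GemmShape<8, 8, 4>,                  // instruction shape
    2,                                                  // pipeline stages
    cutlass::ComplexTransform::kConjugate,              // conjugate A
    cutlass::ComplexTransform::kNone>::ThreadblockMma;  // B unmodified
```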
0
include/cutlass/gemm/threadblock/gemv.h
Normal file → Executable file
@ -118,6 +118,15 @@ public:
  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // Statically assert that kStages for MmaPipelined is two (double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

@ -256,8 +265,8 @@ public:

          __syncthreads();

          ++this->smem_iterator_B_;
          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
@ -299,7 +308,8 @@ public:
          }
        }

        warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
        warp_mma(accum, warp_frag_A[warp_mma_k % 2],
                 warp_frag_B[warp_mma_k % 2], accum);
      }
    }

@ -311,3 +321,5 @@ public:
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

201
include/cutlass/gemm/threadblock/mma_planar_complex_base.h
Normal file
@ -0,0 +1,201 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class MmaPlanarComplexBase {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
                              Shape::kN / WarpGemm::kN,
                              Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  /// Tensor reference to the A operand
  using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

  /// Tensor reference to the B operand
  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    /// Shape of the A matrix operand in shared memory
    using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                               Shape::kK * kStages +
                                   Policy::SmemPaddingA::kColumn>;

    /// Stride to the imaginary part of the A operand
    static int const kImaginaryStrideA = ShapeA::kCount;
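    // The real and imaginary tiles share one contiguous allocation laid out as
    // [ real | imag ], so the imaginary tile of A begins ShapeA::kCount elements
    // after the real tile; see operand_A below and the pipelined mainloop's
    // store_with_pointer_offset calls.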

    /// Shape of the B matrix operand in shared memory
    using ShapeB =
        MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                    Shape::kN + Policy::SmemPaddingB::kColumn>;

    /// Stride to the imaginary part of the B operand
    static int const kImaginaryStrideB = ShapeB::kCount;

   public:
    //
    // Data members
    //

    /// Buffer for A operand
    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount + kImaginaryStrideA> operand_A;

    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount + kImaginaryStrideB> operand_B;

   public:

    //
    // Methods
    //

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() {
      return TensorRefA{operand_A.data(), LayoutA()};
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }
  };

 protected:

  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaPlanarComplexBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx
  ):
    warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
    warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {

  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
422
include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h
Normal file
@ -0,0 +1,422 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_planar_complex_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Iterates over tiles of A operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator |
  //  MaskedTileIterator)
  typename IteratorA_,
  /// Iterates over tiles of A operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorA_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator |
  //  MaskedTileIterator)
  typename IteratorB_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB_,
  /// Data type of accumulator matrix
  typename ElementC_,
  /// Layout of accumulator matrix
  typename LayoutC_,
  /// Policy describing tuning details (concept: MmaPolicy)
  typename Policy_,
  /// Number of stages
  int Stages,
  /// Transformation applied to A
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Transformation applied to B
  ComplexTransform TransformB = ComplexTransform::kNone
>
class MmaPlanarComplexPipelined :
  public MmaPlanarComplexBase<Shape_, Policy_, Stages> {
public:
  ///< Base class
  using Base = MmaPlanarComplexBase<Shape_, Policy_, Stages>;

  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA_;

  ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB_;

  ///< Data type of accumulator matrix
  using ElementC = ElementC_;

  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  using ArchTag = typename Policy::Operator::ArchTag;

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  /// Transformation applied to A
  static ComplexTransform const kTransformA = TransformA;

  /// Transformation applied to B
  static ComplexTransform const kTransformB = TransformB;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC = ArrayPlanarComplex<
    typename Policy::Operator::FragmentC::Element,
    Policy::Operator::FragmentC::kElements
  >;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

private:

  using FragmentA = typename IteratorA::Fragment;
  using FragmentB = typename IteratorB::Fragment;
  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

private:

  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaPlanarComplexPipelined(
    ///< Shared storage needed for internal use by threadblock-scoped GEMM
    typename Base::SharedStorage &shared_storage,
    ///< ID within the threadblock
    int thread_idx,
    ///< ID of warp
    int warp_idx,
    ///< ID of each thread within a warp
    int lane_idx
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
  {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
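    // Worked example: with WarpCount = <2, 2, 2> (8 warps), warp_idx = 5 yields
    // warp_idx_mn = 5 % 4 = 1 and warp_idx_k = 5 / 4 = 1, which decompose further into
    // warp_idx_m = 1 % 2 = 1 and warp_idx_n = 1 / 2 = 0.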

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

private:

  CUTLASS_DEVICE
  void warp_mma_planar_complex(
    Operator & warp_mma,
    FragmentC &accum,
    WarpFragmentA const & real_A,
    WarpFragmentA const & imag_A,
    WarpFragmentB const & real_B,
    WarpFragmentB const & imag_B) {

    cutlass::negate<Array<typename WarpFragmentB::Element, WarpFragmentB::kElements>> neg_op_B;

    WarpFragmentB neg_real_B = neg_op_B(real_B);
    WarpFragmentB neg_imag_B = neg_op_B(imag_B);

    warp_mma(accum.real, real_A, real_B, accum.real);

    if (kTransformB == ComplexTransform::kNone) {
      warp_mma(accum.imag, real_A, imag_B, accum.imag);
    }
    else {
      warp_mma(accum.imag, real_A, neg_imag_B, accum.imag);
    }

    if (kTransformA == ComplexTransform::kNone) {
      warp_mma(accum.imag, imag_A, real_B, accum.imag);
    }
    else {
      warp_mma(accum.imag, imag_A, neg_real_B, accum.imag);
    }

    if ((kTransformA == ComplexTransform::kNone) ^ (kTransformB == ComplexTransform::kNone)) {
      warp_mma(accum.real, imag_A, imag_B, accum.real);
    }
    else {
      warp_mma(accum.real, imag_A, neg_imag_B, accum.real);
    }
  }
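
  // The four real-valued MMAs above compute (A_r + i*a*A_i) * (B_r + i*b*B_i), where a and b
  // are +1, or -1 when the corresponding operand is conjugated:
  //   accum.real += A_r * B_r - (a * b) * A_i * B_i
  //   accum.imag += b * (A_r * B_i) + a * (A_i * B_r)
  // The XOR selects the sign of the A_i * B_i contribution: it is positive exactly when
  // one (but not both) of the operands is conjugated.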

public:

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    ///< number of mainloop iterations over the K dimension
    int gemm_k_iterations,
    ///< destination accumulator tile
    FragmentC &accum,
    ///< iterator over A operand (real part) in global memory
    IteratorA iterator_A_real,
    ///< iterator over A operand (imaginary part) in global memory
    IteratorA iterator_A_imag,
    ///< iterator over B operand (real part) in global memory
    IteratorB iterator_B_real,
    ///< iterator over B operand (imaginary part) in global memory
    IteratorB iterator_B_imag,
    ///< initial value of accumulator
    FragmentC const &src_accum) {

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A_real;
    FragmentA tb_frag_A_imag;

    FragmentB tb_frag_B_real;
    FragmentB tb_frag_B_imag;

    tb_frag_A_real.clear();
    tb_frag_A_imag.clear();

    tb_frag_B_real.clear();
    tb_frag_B_imag.clear();

    // The last kblock is loaded in the prolog
    iterator_A_real.load(tb_frag_A_real);
    iterator_A_imag.load(tb_frag_A_imag);

    iterator_B_real.load(tb_frag_B_real);
    iterator_B_imag.load(tb_frag_B_imag);

    ++iterator_A_real;
    ++iterator_A_imag;

    ++iterator_B_real;
    ++iterator_B_imag;

    this->smem_iterator_A_.store(tb_frag_A_real);
    this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA);

    this->smem_iterator_B_.store(tb_frag_B_real);
    this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB);

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA warp_frag_real_A[2];
    WarpFragmentA warp_frag_imag_A[2];

    WarpFragmentB warp_frag_real_B[2];
    WarpFragmentB warp_frag_imag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_real_A[0]);
    this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA);

    this->warp_tile_iterator_B_.load(warp_frag_real_B[0]);
    this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    if (gemm_k_iterations <= 1) {
      iterator_A_real.clear_mask();
      iterator_A_imag.clear_mask();

      iterator_B_real.clear_mask();
      iterator_B_imag.clear_mask();
    }

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {

          // Write fragments to shared memory
          this->smem_iterator_A_.store(tb_frag_A_real);
          this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA);

          this->smem_iterator_B_.store(tb_frag_B_real);
          this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB);

          __syncthreads();

          ++this->smem_iterator_B_;
          ++this->smem_iterator_A_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA);

        this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k == 0) {

          iterator_A_real.load(tb_frag_A_real);
          iterator_A_imag.load(tb_frag_A_imag);

          iterator_B_real.load(tb_frag_B_real);
          iterator_B_imag.load(tb_frag_B_imag);

          ++iterator_A_real;
          ++iterator_A_imag;
          ++iterator_B_real;
          ++iterator_B_imag;

          // Avoid reading out of bounds if this was the last loop iteration
          if (gemm_k_iterations <= 2) {
            iterator_A_real.clear_mask();
            iterator_A_imag.clear_mask();
            iterator_B_real.clear_mask();
            iterator_B_imag.clear_mask();
          }
        }

        warp_mma_planar_complex(
          warp_mma,
          accum,
          warp_frag_real_A[warp_mma_k % 2],
          warp_frag_imag_A[warp_mma_k % 2],
          warp_frag_real_B[warp_mma_k % 2],
          warp_frag_imag_B[warp_mma_k % 2]);
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
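For reference, the sign logic in `warp_mma_planar_complex` realizes the following identity, restated here from the code above with $\xi_A, \xi_B \in \{+1, -1\}$ equal to $-1$ when the corresponding operand is conjugated:

```
\begin{aligned}
(A_r + i\,\xi_A A_i)\,(B_r + i\,\xi_B B_i)
  &= \big(A_r B_r - \xi_A \xi_B\, A_i B_i\big)
   + i\,\big(\xi_B\, A_r B_i + \xi_A\, A_i B_r\big)
\end{aligned}
```

The first parenthesized term is accumulated into `accum.real` and the second into `accum.imag`, each via real-valued Tensor Core MMAs.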
@ -168,7 +168,6 @@ public:
    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

@ -183,8 +182,9 @@ public:
    ++iterator_B;

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];
    WarpFragmentA warp_frag_A;
    WarpFragmentB warp_frag_B;

    Operator warp_mma;

    // Avoid reading out of bounds
@ -193,7 +193,6 @@ public:
      iterator_B.clear_mask();
    }

    //
    // Mainloop
    //
@ -203,7 +202,6 @@ public:
      this->smem_iterator_A_.store(tb_frag_A);
      this->smem_iterator_B_.store(tb_frag_B);

      __syncthreads();

      //
@ -216,16 +214,16 @@ public:
        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k) % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k) % 2]);
        this->warp_tile_iterator_A_.load(warp_frag_A);
        this->warp_tile_iterator_B_.load(warp_frag_B);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
        warp_mma(accum, warp_frag_A, warp_frag_B, accum);
      }

      // Add negative offsets to return smem load iterators to the 'start' of the shared memory

@ -30,7 +30,8 @@
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -141,6 +142,55 @@ struct GemmIdentityThreadblockSwizzle {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// A special version of GemmIdentityThreadblockSwizzle. See the choice of kTile below.
template <typename LayoutA_, typename LayoutB_>
struct GemmCohortThreadblockSwizzle
{
  const int kTile =
      (platform::is_same<LayoutA_, cutlass::layout::RowMajor>::value ||
       platform::is_same<LayoutB_, cutlass::layout::ColumnMajor>::value)
          ? 4
          : 1;

  CUTLASS_HOST_DEVICE
  GemmCohortThreadblockSwizzle() { }

  /// Returns the shape of the problem in units of logical tiles
  CUTLASS_HOST_DEVICE
  GemmCoord get_tiled_shape(
    GemmCoord problem_size,
    GemmCoord tile_size,
    int split_k_slices) const {

    return GemmCoord(
      (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
      (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
      split_k_slices);
  }

  /// Computes CUDA grid dimensions given a size in units of logical tiles
  CUTLASS_HOST_DEVICE
  dim3 get_grid_shape(GemmCoord tiled_shape) const {
    return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
  }

  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
  CUTLASS_DEVICE
  GemmCoord get_tile_offset() const {

    int block_idx_x = RematerializeBlockIdxX();
    int block_idx_y = RematerializeBlockIdxY();

    return GemmCoord{
      (block_idx_x / kTile),
      (block_idx_y * kTile) + (block_idx_x % kTile),
      RematerializeBlockIdxZ()
    };
  }
};
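
// Worked example of the cohort mapping above, assuming kTile == 4: the grid is launched as
// dim3(tiled_m * 4, (tiled_n + 3) / 4, k), and blockIdx = (9, 1) maps to logical tile
// (9 / 4, 1 * 4 + 9 % 4) = (2, 5). Four consecutive blockIdx.x values therefore cover a
// 1x4 "cohort" of N-adjacent tiles sharing the same row of A tiles, which tends to improve
// cache reuse of the A operand.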

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Threadblock swizzling function for GEMMs
struct GemmHorizontalThreadblockSwizzle {

@ -186,8 +236,8 @@ struct GemmBatchedIdentityThreadblockSwizzle {
  CUTLASS_HOST_DEVICE
  GemmCoord get_tiled_shape(
    GemmCoord problem_size,
    int batch_count,
    GemmCoord tile_size) const {
    GemmCoord tile_size,
    int batch_count) const {

    return GemmCoord(
      (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
@ -207,7 +257,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
    return GemmCoord{
      RematerializeBlockIdxX(),
      RematerializeBlockIdxY(),
      0
      RematerializeBlockIdxZ()
    };
  }

@ -114,3 +114,6 @@ struct DefaultMmaTensorOp {
} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -1,351 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
      Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"

#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Data type of A elements
  typename RealElementA,
  /// Layout of A matrix (concept: MatrixLayout)
  typename LayoutA_,
  /// Data type of B elements
  typename RealElementB,
  /// Layout of B matrix (concept: MatrixLayout)
  typename LayoutB_,
  /// Element type of C matrix
  typename RealElementC,
  /// Layout of C matrix (concept: MatrixLayout)
  typename LayoutC_,
  /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
  typename Policy_,
  /// Complex transform on A operand
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Complex transform on B operand
  ComplexTransform TransformB = ComplexTransform::kNone,
  /// Used for partial specialization
  typename Enable = bool
>
class MmaComplexTensorOp;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Data type of A elements
  typename RealElementA,
  /// Layout of A matrix (concept: MatrixLayout)
  typename LayoutA_,
  /// Data type of B elements
  typename RealElementB,
  /// Layout of B matrix (concept: MatrixLayout)
  typename LayoutB_,
  /// Element type of C matrix
  typename RealElementC,
  /// Layout of C matrix (concept: MatrixLayout)
  typename LayoutC_,
  /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
  typename Policy_,
  /// Complex transform on A operand
  ComplexTransform TransformA,
  /// Complex transform on B operand
  ComplexTransform TransformB,
  /// Used for partial specialization
  typename Enable
>
class MmaComplexTensorOp<
  Shape_,
  complex<RealElementA>,
  LayoutA_,
  complex<RealElementB>,
  LayoutB_,
  complex<RealElementC>,
  LayoutC_,
  Policy_,
  TransformA,
  TransformB,
  Enable> {
public:
  /// Shape of warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_;

  /// Data type of multiplicand A
  using ElementA = complex<RealElementA>;

  /// Layout of multiplicand A
  using LayoutA = LayoutA_;

  /// Data type of multiplicand B
  using ElementB = complex<RealElementB>;

  /// Layout of multiplicand B
  using LayoutB = LayoutB_;

  /// Data type of accumulator matrix C
  using ElementC = complex<RealElementC>;

  /// Layout of accumulator matrix C
  using LayoutC = LayoutC_;

  /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
  using Policy = Policy_;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = TransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = TransformB;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of threads participating in warp-level matrix product
  static int const kThreadCount = 32;

public:

  /// Iterates over the A operand in memory
  using IteratorA = MmaTensorOpMultiplicandTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    Operand::kA,
    ElementA,
    LayoutA,
    MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
    Policy::OpDelta::kRow,
    32,
    1
  >;

  /// Storage for A tile
  using FragmentA = typename IteratorA::Fragment;

  /// Iterates over the B operand in memory
  using IteratorB = MmaTensorOpMultiplicandTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    Operand::kB,
    ElementB,
    LayoutB,
    MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
    Policy::OpDelta::kColumn,
    32,
    1
  >;

  /// Storage for B tile
  using FragmentB = typename IteratorB::Fragment;

  static_assert(
    !(Shape::kM % Policy::Operator::Shape::kM) &&
    !(Shape::kN % Policy::Operator::Shape::kN),
    "Shape of warp-level Mma must be divisible by operator shape.");

  /// Number of mma operations performed
  using MmaIterations = MatrixShape<
    Shape::kM / Policy::Operator::Shape::kM,
    Shape::kN / Policy::Operator::Shape::kN
  >;

  /// Iterates over the C operand in memory
  using IteratorC = MmaTensorOpAccumulatorTileIterator<
    MatrixShape<Shape::kM, Shape::kN>,
    ElementC,
    LayoutC,
    typename Policy::Operator::Shape,
    typename Policy::OpDelta>;

  /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this
  /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued
  /// parts are stored consecutively followed by all imaginary parts. This matches the structure
  /// of Tensor Cores which are always real-valued matrix multiplies.
  using FragmentC = typename IteratorC::Fragment;

  static_assert(
    FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,
    "Unexpected planar complex fragment length.");
private:

  //
  // Data members
  //

  /// Underlying real-valued matrix multiply operator (concept: arch::Mma)
  typename Policy::Operator mma;

public:

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  MmaComplexTensorOp() {}

  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A,
    FragmentB const &B,
    FragmentC const &C) const {

    // Alias types for underlying real-valued matrix multiply operator
    using MmaOperandA = typename Policy::Operator::FragmentA;
    using MmaOperandB = typename Policy::Operator::FragmentB;
    using MmaOperandC = typename Policy::Operator::FragmentC;

    static_assert(MmaOperandA::kElements == 1,
      "This implementation only supports math instructions in which exactly one element is needed for the A operand. "
      "We can generalize later.");

    static_assert(MmaOperandB::kElements == 1,
      "This implementation only supports math instructions in which exactly one element is needed for the B operand. "
      "We can generalize later.");

    D = C;

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < MmaIterations::kRow; ++m) {

      // mma(accum.real(), a.real(), b.real(), accum.real());
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < MmaIterations::kColumn; ++n) {

        // Pack operands together. This may result in actual MOVs
        MmaOperandA operand_A;
        MmaOperandB operand_B;

        operand_A[0] = A[m].real();
        operand_B[0] = B[n].real();

        // Real part of the accumulator
        MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
          (m + n * MmaIterations::kRow);

        mma(*accum, operand_A, operand_B, *accum);
      }

      // mma(accum.imag(), a.real(), b.imag(), accum.imag());
      CUTLASS_PRAGMA_UNROLL
      for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {

        // Pack operands together. This may result in actual MOVs
        MmaOperandA operand_A;
        MmaOperandB operand_B;

        operand_A[0] = A[m].real();
        operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());

        // Imaginary part of the accumulator
        MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
          (m + n * MmaIterations::kRow) + MmaIterations::kCount;

        mma(*accum, operand_A, operand_B, *accum);
      }

      // mma(accum.real(), -a.imag(), b.imag(), accum.real())
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < MmaIterations::kColumn; ++n) {

        // Pack operands together. This may result in actual MOVs
        MmaOperandA operand_A;
        MmaOperandB operand_B;

        // A imaginary part is intentionally negated
        operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag());
        operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());

        // Real part of the accumulator
        MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
          (m + n * MmaIterations::kRow);

        mma(*accum, operand_A, operand_B, *accum);
      }

      // mma(accum.imag(), a.imag(), b.real(), accum.imag())
      CUTLASS_PRAGMA_UNROLL
      for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {

        // Pack operands together. This may result in actual MOVs
        MmaOperandA operand_A;
        MmaOperandB operand_B;

        operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag());
        operand_B[0] = B[n].real();

        // Imaginary part of the accumulator
        MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
          (m + n * MmaIterations::kRow) + MmaIterations::kCount;

        mma(*accum, operand_A, operand_B, *accum);
      }
    }
  }
};
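The four inner loops are the textbook decomposition of a complex multiply-accumulate into four real-valued MMAs: d.r += a.r·b.r − a.i·b.i and d.i += a.r·b.i + a.i·b.r. A scalar sketch of the same schedule:

```
#include <cassert>

struct Cplx { float r, i; };

// One complex FMA expressed as four real multiply-accumulates,
// issued in the same order as the four loops above.
Cplx complex_mma(Cplx a, Cplx b, Cplx acc) {
  acc.r += a.r * b.r;   // pass 1: real * real  -> real
  acc.i += a.r * b.i;   // pass 2: real * imag  -> imag
  acc.r += -a.i * b.i;  // pass 3: -imag * imag -> real
  acc.i += a.i * b.r;   // pass 4: imag * real  -> imag
  return acc;
}

int main() {
  Cplx d = complex_mma({1, 2}, {3, 4}, {0, 0});
  assert(d.r == 1 * 3 - 2 * 4);  // -5
  assert(d.i == 1 * 4 + 2 * 3);  // 10
  return 0;
}
```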
/////////////////////////////////////////////////////////////////////////////////////////////////

// TODO - partial specializations of real*complex and complex*real

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
176
include/cutlass/gemm/warp/mma_planar_complex.h
Normal file
@ -0,0 +1,176 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/array_planar_complex.h"
#include "cutlass/gemm/warp/tile_iterator_planar_complex.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Underlying real-valued warp-level matrix multiply
  typename Operator_,
  /// Transformation applied to A operand (typically folded into math instruction)
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Transformation applied to B operand (typically folded into math instruction)
  ComplexTransform TransformB = ComplexTransform::kNone
>
class MmaPlanarComplex {
public:

  /// Underlying real-valued warp-level matrix multiply
  using Operator = Operator_;

  /// Shape of warp-level matrix multiply
  using Shape = typename Operator::Shape;

  /// Transformation applied to A operand (typically folded into math instruction)
  static ComplexTransform const kTransformA = TransformA;

  /// Transformation applied to B operand (typically folded into math instruction)
  static ComplexTransform const kTransformB = TransformB;

  /// Fragment of elements
  using FragmentA = ArrayPlanarComplex<typename Operator::ElementA, Operator::FragmentA::kElements>;

  /// Iterator into planar complex
  using IteratorA = TileIteratorPlanarComplex<typename Operator::IteratorA>;

  /// Layout in memory of the A operand
  using LayoutA = typename Operator::LayoutA;

  /// Fragment of elements
  using FragmentB = ArrayPlanarComplex<typename Operator::ElementB, Operator::FragmentB::kElements>;

  /// Iterator into planar complex
  using IteratorB = TileIteratorPlanarComplex<typename Operator::IteratorB>;

  /// Layout in memory of the B operand
  using LayoutB = typename Operator::LayoutB;

  /// Tile iterator for accumulator
  using IteratorC = TileIteratorPlanarComplex<typename Operator::IteratorC>;

  /// Accumulator fragment
  using FragmentC = ArrayPlanarComplex<typename Operator::ElementC, Operator::FragmentC::kElements>;

  /// Layout of accumulator fragment in memory
  using LayoutC = typename Operator::LayoutC;

private:

  /// Number of mma operations performed
  using MmaIterations = MatrixShape<
    Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,
    Operator::Shape::kN / Operator::Policy::Operator::Shape::kN
  >;

public:
  /// Ctor
  CUTLASS_DEVICE
  MmaPlanarComplex() {}

  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A_in,
    FragmentB const &B_in,
    FragmentC const &C) const {

    D.real = C.real;
    D.imag = C.imag;

    //
    // Transform fragments based on conjugate operations.
    //

    negate<typename FragmentA::ArrayReal> neg_A;

    FragmentA frag_A;
    frag_A.real = A_in.real;

    if (kTransformA == ComplexTransform::kConjugate) {
      frag_A.imag = neg_A(A_in.imag);
    }
    else {
      frag_A.imag = A_in.imag;
    }

    FragmentB frag_B;
    frag_B.real = B_in.real;

    if (kTransformB == ComplexTransform::kConjugate) {
      negate<typename FragmentB::ArrayReal> neg;
      frag_B.imag = neg(B_in.imag);
    }
    else {
      frag_B.imag = B_in.imag;
    }

    //
    // Accumulated real-valued matrix multiplies
    //

    Operator real_mma;

    // D.i += A.i * B.r
    real_mma(D.imag, frag_A.imag, frag_B.real, D.imag);

    // D.r += A.r * B.r
    real_mma(D.real, frag_A.real, frag_B.real, D.real);

    // D.i += A.r * B.i
    real_mma(D.imag, frag_A.real, frag_B.imag, D.imag);

    // D.r += -A.i * B.i
    frag_A.imag = neg_A(frag_A.imag);
    real_mma(D.real, frag_A.imag, frag_B.imag, D.real);
  }
};
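Conjugation is applied once, up front, by negating the imaginary fragment, so the four real MMAs never need to branch. A scalar sketch of the same transform (the enum and struct names here are stand-ins for the CUTLASS types):

```
#include <cassert>

enum class ComplexTransform { kNone, kConjugate };

struct Cplx { float r, i; };

// Apply the operand transform before any multiply, as above.
Cplx apply_transform(Cplx x, ComplexTransform t) {
  if (t == ComplexTransform::kConjugate) {
    x.i = -x.i;  // conj(a + bi) = a - bi
  }
  return x;
}

int main() {
  Cplx a = apply_transform({1.0f, 2.0f}, ComplexTransform::kConjugate);
  assert(a.r == 1.0f && a.i == -2.0f);
  return 0;
}
```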
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -100,6 +100,16 @@ public:
  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassSimt;

  /// Hard-coded for now
  using ArchTag = arch::Sm50;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = ComplexTransform::kNone;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Layout of threads
  using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value,
      layout::ColumnMajor,
      typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value,

@ -153,6 +163,9 @@ public:
  /// Storage for A tile
  using FragmentA = typename IteratorA::Fragment;

  /// Storage for transformed A tile
  using TransformedFragmentA = FragmentA;

  /// Iterates over the B operand in memory
  using IteratorB = MmaSimtTileIterator<
    MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,

@ -167,6 +180,9 @@ public:
  /// Storage for B tile
  using FragmentB = typename IteratorB::Fragment;

  /// Storage for transformed B tile
  using TransformedFragmentB = FragmentB;

  /// Iterates over the C operand in memory
  using IteratorC = MmaSimtTileIterator<
    MatrixShape<Shape::kM, Shape::kN>,

@ -201,6 +217,15 @@ public:

    mma(d, a, b, c);
  }

  /// Transform the mma operands to the required types
  CUTLASS_DEVICE
  void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                 FragmentA const &A, FragmentB const &B) const {
    //TODO: Implement this
    dst_A = A;
    dst_B = B;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -31,7 +31,9 @@

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"

#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

@ -51,6 +53,60 @@ namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

template <typename T, typename S, int N, FloatRoundStyle Round>
struct ConvertAndPack {

  using Converter = NumericArrayConverter<T, S, N, Round>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<S, N> const &source) {
    Converter converter;

    return converter(source);
  }
};

template <typename T, int N, FloatRoundStyle Round>
struct ConvertAndPack<T, T, N, Round> {

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &source) {
    return source;
  }
};

template <int N, FloatRoundStyle Round>
struct ConvertAndPack<half_t, float, N, Round> {

  using Converter = NumericArrayConverter<half_t, float, N, Round>;

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<float, N> const &source) {
    Converter converter;

    Array<float, N> tmp;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc));
      tmp[i] = source[idx];
    }

    return converter(tmp);
  }
};
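The index expression swaps the middle pair of every aligned group of four (0,1,2,3 → 0,2,1,3), rearranging the float values into the register ordering expected after the float→half conversion. A quick host-side check of that bit trick:

```
#include <cassert>

// Same index permutation as above: exchange bits 0 and 1 of the
// element index while keeping the group-of-4 base (i & ~3).
int permute(int i) {
  return ((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc);
}

int main() {
  int expected[8] = {0, 2, 1, 3, 4, 6, 5, 7};
  for (int i = 0; i < 8; ++i) {
    assert(permute(i) == expected[i]);
  }
  return 0;
}
```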
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>

@ -105,9 +161,18 @@ public:
  /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
  using Policy = Policy_;

  /// Architecture tag from underlying instruction
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = ComplexTransform::kNone;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Number of threads participating in warp-level matrix product
  static int const kThreadCount = 32;

@ -128,6 +193,10 @@ public:
  /// Storage for A tile
  using FragmentA = typename IteratorA::Fragment;

  /// Storage for transformed A tile
  using TransformedFragmentA =
      Array<typename Policy::Operator::ElementA, FragmentA::kElements>;

  /// Iterates over the B operand in memory
  using IteratorB = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,

@ -137,6 +206,10 @@ public:
  /// Storage for B tile
  using FragmentB = typename IteratorB::Fragment;

  /// Storage for transformed B tile
  using TransformedFragmentB =
      Array<typename Policy::Operator::ElementB, FragmentB::kElements>;

  /// Iterates over the C operand in memory
  using IteratorC = MmaTensorOpAccumulatorTileIterator<
     MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,

@ -179,8 +252,8 @@ public:
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    FragmentA const &A,
    FragmentB const &B,
    TransformedFragmentA const &A,
    TransformedFragmentB const &B,
    FragmentC const &C,
    int const &partitionN_idx = 0) const {

@ -221,6 +294,44 @@ public:
      }
    }
  }

  /// Transform the mma operands to the required types
  CUTLASS_DEVICE
  void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                 FragmentA const &A, FragmentB const &B) const {
    bool midway_depstage =
        !(platform::is_same<typename Policy::Operator::ElementA,
                            ElementA>::value &&
          platform::is_same<typename Policy::Operator::ElementB,
                            ElementB>::value);

    //
    // Define conversions from source type to instruction type
    //
    FloatRoundStyle const kRoundA =
        PreferredRoundingMode<typename Policy::Operator::ElementA,
                              ElementA>::kRound;
    FloatRoundStyle const kRoundB =
        PreferredRoundingMode<typename Policy::Operator::ElementB,
                              ElementB>::kRound;
    detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
                           FragmentA::kElements, kRoundA>
        convert_A;
    NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
                          FragmentB::kElements / 2, kRoundB>
        convert_B;
    Array<ElementB, FragmentB::kElements / 2> const *ptr_B =
        reinterpret_cast<Array<ElementB, FragmentB::kElements / 2> const *>(&B);
    Array<typename Policy::Operator::ElementB, FragmentB::kElements / 2> *
        ptr_dst_B = reinterpret_cast<Array<typename Policy::Operator::ElementB,
                                           FragmentB::kElements / 2> *>(&dst_B);

    dst_A = convert_A(A);

    ptr_dst_B[0] = convert_B(ptr_B[0]);
    ptr_dst_B[1] = convert_B(ptr_B[1]);
  }
};
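With this change the caller converts fragments once and then feeds the transformed (instruction-native) fragments to operator(). A self-contained toy sketch of the new calling convention — the ToyWarpMma type and its scalar "fragments" are stand-ins for the real CUTLASS instantiations:

```
#include <cassert>

// Toy stand-in for a warp-level MMA exposing the transform() hook above.
struct ToyWarpMma {
  using FragmentA = double;            // "wide" source fragment
  using FragmentB = double;
  using TransformedFragmentA = float;  // instruction-native fragment
  using TransformedFragmentB = float;
  using FragmentC = float;

  void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                 FragmentA const &A, FragmentB const &B) const {
    dst_A = float(A);  // conversion happens once, up front
    dst_B = float(B);
  }

  void operator()(FragmentC &D, TransformedFragmentA const &A,
                  TransformedFragmentB const &B, FragmentC const &C) const {
    D = C + A * B;     // stand-in for the tensor core MMAs
  }
};

int main() {
  ToyWarpMma mma;
  ToyWarpMma::TransformedFragmentA xa;
  ToyWarpMma::TransformedFragmentB xb;
  ToyWarpMma::FragmentC acc = 1.0f;

  mma.transform(xa, xb, 2.0, 3.0);  // convert operands once
  mma(acc, xa, xb, acc);            // then issue the MMA
  assert(acc == 7.0f);
  return 0;
}
```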
/////////////////////////////////////////////////////////////////////////////////////////////////

@ -228,3 +339,5 @@ public:

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -103,6 +103,15 @@ public:
  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Architecture tag
  using ArchTag = arch::Sm70;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = ComplexTransform::kNone;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Number of threads participating in warp-level matrix product
  static int const kThreadCount = 32;
@ -199,7 +199,8 @@ public:
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads>;
  using Fragment =
      Array<Element, Shape::kContiguous * InstructionShape::kStrided / kThreads>;

private:
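The old fragment sized a thread's registers for the whole tile (Shape::kCount / kThreads); the new one holds only the slice a single instruction's k-group consumes. With made-up numbers the difference is easy to see:

```
#include <cassert>

int main() {
  // Hypothetical tile: 64 contiguous x 16 strided elements, 32 threads/warp.
  int kContiguous = 64, kStrided = 16, kThreads = 32;
  int old_fragment = (kContiguous * kStrided) / kThreads;  // whole tile: 32

  // New sizing: only one instruction's strided extent (say 4) per fragment.
  int instruction_kStrided = 4;
  int new_fragment = (kContiguous * instruction_kStrided) / kThreads;  // 8

  assert(old_fragment == 32 && new_fragment == 8);
  return 0;
}
```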
@ -516,7 +517,7 @@ class MmaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads>;
  using Fragment = typename Base::Fragment;

private:

@ -747,7 +748,7 @@ class MmaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads>;
  using Fragment = typename Base::Fragment;

private:

@ -1023,7 +1024,8 @@ class MmaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads>;
  using Fragment = Array<Element, Shape::kStrided *
                                      InstructionShape::kContiguous / kThreads>;

private:

@ -1151,7 +1153,8 @@ class MmaTensorOpMultiplicandTileIterator<
    int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile;

    byte_offset_ ^= k_groups_delta * sizeof_bits<Element>::value *
                    Layout::kElementsPerAccess / 8;
                    Layout::kElementsPerAccess *
                    Policy::LdsmShape::kContiguous / 8;
    pointer_ +=
        tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor +
        whole_tiles * stride_ / sections_;

@ -1406,7 +1409,7 @@ class MmaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads>;
  using Fragment = typename Base::Fragment;

private:
  /// Underlying tile iterator

@ -1636,7 +1639,7 @@ class MmaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads>;
  using Fragment = typename Base::Fragment;

private:
  /// Underlying tile iterator
@ -165,7 +165,8 @@ public:
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment = Array<Element, Shape::kContiguous *
                                      InstructionShape::kStrided / kThreads * 2>;

private:

@ -473,7 +474,8 @@ public:
  //

  /// Fragment object holding a thread's part of a tile; needs twice the number of registers
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment = Array<Element, Shape::kContiguous *
                                      InstructionShape::kStrided / kThreads * 2>;

private:

@ -738,7 +740,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment = typename Base::Fragment;

private:

@ -962,7 +964,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment = typename Base::Fragment;

private:

@ -1557,7 +1559,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment =
      Array<Element,
            Shape::kStrided * InstructionShape::kContiguous / kThreads * 2>;

private:

@ -1869,7 +1873,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment = typename Base::Fragment;

private:
  /// Underlying tile iterator

@ -2097,7 +2101,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
  using Fragment = typename Base::Fragment;

private:
  /// Underlying tile iterator

@ -106,6 +106,15 @@ public:
  /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy)
  using Policy = Policy_;

  /// Underlying architecture tag
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = ComplexTransform::kNone;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

@ -193,7 +202,6 @@ public:
      }
    }
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////
244
include/cutlass/gemm/warp/tile_iterator_planar_complex.h
Normal file
@ -0,0 +1,244 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/array_planar_complex.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileIterator_>
class TileIteratorPlanarComplex {
public:

  /// Underlying iterator over real-valued tiles
  using TileIterator = TileIterator_;

  /// Underlying element type
  using Element = typename TileIterator::Element;

  /// Underlying layout type
  using Layout = typename TileIterator::Layout;

  /// TensorRef type for loading element from a tensor
  using TensorRef = typename TileIterator::TensorRef;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Planar complex fragment
  using Fragment = ArrayPlanarComplex<Element, TileIterator::Fragment::kElements>;

public:

  /// Underlying tile iterator
  TileIterator tile_iterator_;

  /// Offset (in units of bytes) to the imaginary part of the planar complex matrix
  LongIndex imaginary_offset_;

public:
  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  TileIteratorPlanarComplex(): imaginary_offset_(0) { }

  /// Constructor from TensorRef
  CUTLASS_DEVICE
  TileIteratorPlanarComplex(
    TensorRef const &ref,
    int lane_id,
    LongIndex imaginary_offset
  ):
    tile_iterator_(ref, lane_id),
    imaginary_offset_((imaginary_offset * sizeof_bits<Element>::value) / 8) { }
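The constructor takes the imaginary offset in elements and stores it in bytes. For a disjoint real/imaginary allocation the offset is simply the number of elements in one plane; a hypothetical host-side setup:

```
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical planar complex matrix: 128x128 half-precision elements,
  // real plane followed immediately by the imaginary plane.
  int64_t rows = 128, cols = 128;
  int64_t imaginary_offset = rows * cols;       // in elements
  std::vector<uint16_t> data(2 * rows * cols);  // [real plane | imag plane]

  // The iterator stores the offset in bytes, as the ctor above does:
  int64_t imaginary_offset_bytes = imaginary_offset * 16 /*bits*/ / 8;  // 32768

  (void)data;
  (void)imaginary_offset_bytes;
  return 0;
}
```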
  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_DEVICE
  TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) {

    tile_iterator_.add_pointer_offset(offset);

    return *this;
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) {

    tile_iterator_.add_tile_offset(tile_offset);

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  TileIteratorPlanarComplex & operator++() {
    ++tile_iterator_;
    return *this;
  }

  //
  // WIP
  //

  /// Advances the iterator along the opposite of the advance dimension
  CUTLASS_HOST_DEVICE
  TileIteratorPlanarComplex & operator--() {
    --tile_iterator_;
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) {
    tile_iterator_.add_tile_offset(tile_offset);
    return *this;
  }

  ///< advances in units of whole tiles along the logical coordinate space of the tensor
  CUTLASS_DEVICE
  TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) {
    tile_iterator_.add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {

    tile_iterator_.load_with_byte_offset(frag.real, 0);
    tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_);
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_byte_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a linear offset in units of bytes
    Index byte_offset) const {

    tile_iterator_.load_with_byte_offset(frag.real, byte_offset);
    tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_);
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_DEVICE
  void load_with_pointer_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a linear offset
    Index pointer_offset) const {

    Index byte_offset = (pointer_offset * sizeof_bits<Element>::value) / 8;

    tile_iterator_.load_with_byte_offset(frag.real, byte_offset);
    tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset) const {

    tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0);
    tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// loads a tile with a logical offset AND a pointer offset
    Index pointer_offset) const {

    Index byte_offset = (pointer_offset * sizeof_bits<Element>::value) / 8;

    tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset);
    tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_);
  }

  /// Loads a fragment from memory with logical offset in units of whole tiles.
  CUTLASS_DEVICE
  void load_with_byte_offset(
    /// fragment to load from the tensor
    Fragment &frag,
    /// loads a tile with a logical offset in units of whole tiles
    TensorCoord const &tile_offset,
    /// loads a tile with a logical offset AND a pointer offset
    Index byte_offset) const {

    tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset);
    tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_);
  }

  /// Notify the iterator which k-group it is currently pointing to.
  ///
  /// This does not advance the iterator. Rather, it overrides its internal
  /// tracking with constant-valued k-group index to enable the compiler to
  /// fold constants and achieve more efficient code.
  ///
  /// This is used by some nontrivial permuted layouts.
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    tile_iterator_.set_kgroup_index(k_group);
  }
};
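Every logical load issues two real-valued loads, one per plane, separated by the stored byte offset. A self-contained toy model of that access pattern (the plane size is illustrative):

```
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model of TileIteratorPlanarComplex::load(): one logical load
// performs two real-valued loads separated by the imaginary offset.
int main() {
  std::size_t plane = 16;               // elements per plane (illustrative)
  std::vector<float> tensor(2 * plane);
  tensor[3] = 1.0f;                     // real plane, element 3
  tensor[3 + plane] = -1.0f;            // imaginary plane, element 3

  float real = tensor[3];               // first underlying load
  float imag = tensor[3 + plane];       // second load at +imaginary_offset
  assert(real == 1.0f && imag == -1.0f);
  return 0;
}
```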
/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -78,6 +78,10 @@ enum

#include <immintrin.h>

#if defined(__i386__) || defined(__x86_64__)
#include <intrin.h>
#endif

#define F16C_ROUND_NEAREST 0

#if !defined(__CUDA_ARCH__)

@ -110,9 +114,51 @@ __inline unsigned short _cvtss_sh (float __F, const int) {

// Linux
#include <x86intrin.h>

#if defined(__i386__) || defined(__x86_64__)
#include <cpuid.h>
#endif

#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)

#endif
#endif // _MSC_VER

class CpuId {

  bool f16c_enabled;

  CpuId() {
#if defined(__i386__) || defined(__x86_64__)
#if defined(_MSC_VER)
    int exx[4];

    __cpuid (exx, 1);
    f16c_enabled = exx[2] & 0x20000000;

#else
    // GCC / Clang
    int eax, ebx, ecx, edx;

    __cpuid (1 , eax, ebx, ecx, edx);
    f16c_enabled = ecx & 0x20000000;
#endif
#else
    // Arm / PowerPC etc.
    f16c_enabled = false;
#endif
  }

public:

  bool is_f16c_supported() const {
    return f16c_enabled;
  }

  static const CpuId& instance() {
    static CpuId cpu;
    return cpu;
  }
};
#endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
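CpuId is a Meyers singleton, so the CPUID query runs once and the F16C path is selected at runtime instead of at compile time. A hedged host-only sketch of the same probe, for GCC/Clang on x86-64 (outside CUTLASS, using the standard `__get_cpuid` helper from `<cpuid.h>`):

```
#include <cpuid.h>
#include <cstdio>

int main() {
  unsigned eax = 0, ebx = 0, ecx = 0, edx = 0;
  __get_cpuid(1, &eax, &ebx, &ecx, &edx);

  // Bit 29 of ECX reports F16C, exactly as CpuId() tests above.
  bool f16c = (ecx >> 29) & 1;
  std::printf("F16C %s\n", f16c ? "supported" : "not supported");
  return 0;
}
```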
///////////////////////////////////////////////////////////////////////////////////////////////////

@ -154,10 +200,15 @@ struct alignas(2) half_t {
  static half_t convert(float const& flt) {
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
    return half_t(__float2half_rn(flt));
  #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
    unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST);
    return bitcast(u);
  #else

  #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
    if( CpuId::instance().is_f16c_supported() ) {
      unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST);
      return bitcast(u);
    }
  #endif

    // software implementation rounds toward nearest even
    unsigned const& s = reinterpret_cast<unsigned const &>(flt);
    uint16_t sign = uint16_t((s >> 16) & 0x8000);

@ -248,10 +299,15 @@ struct alignas(2) half_t {
  static float convert(half_t const& x) {
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
    return __half2float(x.to_half());
  #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
    unsigned short u = x.storage;
    return _cvtsh_ss(u);
  #else

  #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
    if( CpuId::instance().is_f16c_supported() ) {
      unsigned short u = x.storage;
      return _cvtsh_ss(u);
    }
  #endif

    uint16_t const &h = x.storage;
    int sign = ((h >> 15) & 1);
    int exp = ((h >> 10) & 0x1f);
@ -29,7 +29,11 @@
*/
#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif

#include "cutlass/platform/platform.h"

@ -48,7 +52,7 @@ struct integer_subbyte {
  static bool const kSigned = Signed;

  /// External type
  using T = typename std::conditional<kSigned, int, unsigned>::type;
  using T = typename platform::conditional<kSigned, int, unsigned>::type;

  /// Storage type
  using Storage = uint8_t;
@ -38,7 +38,6 @@
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/tensor_nhwc.h"
#include "cutlass/layout/vector.h"

#include "cutlass/layout/tensor_op_multiplicand_sm70.h"

@ -98,7 +98,7 @@ public:
  /// Assumes coordinate has convention (row, column)
  CUTLASS_HOST_DEVICE
  LongIndex operator()(MatrixCoord const &coord) const {
    return coord.row() * stride_[0] + coord.column();
    return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column();
  }

  /// Inverse of layout function, mapping linear offset to logical coordinate

@ -134,7 +134,7 @@ public:
  /// Compute the number of contiguous elements needed to store a tensor with the given size
  CUTLASS_HOST_DEVICE
  LongIndex capacity(MatrixCoord const &extent) const {
    return extent.row() * stride_[0];
    return LongIndex(extent.row()) * LongIndex(stride_[0]);
  }
};
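These casts matter because Index is 32-bit: the old expressions multiplied two 32-bit values and only then widened, so offsets for large tensors overflowed before they ever reached LongIndex. A minimal illustration:

```
#include <cassert>
#include <cstdint>

int main() {
  int32_t rows = 70000, stride = 70000;     // a 70000 x 70000 tensor

  // Widen BEFORE multiplying, as the patched operators do:
  int64_t right = int64_t(rows) * int64_t(stride);
  assert(right == 4900000000LL);            // does not fit in 32 bits

  // The old code multiplied in 32 bits first; 4.9e9 cannot be represented
  // there, so the offset was wrong (and signed overflow is undefined).
  return 0;
}
```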
@ -191,7 +191,7 @@ public:
  /// Assumes coordinate has convention (row, column)
  CUTLASS_HOST_DEVICE
  LongIndex operator()(MatrixCoord const &coord) const {
    return coord.row() + coord.column() * stride_[0];
    return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row();
  }

  /// Inverse of layout function, mapping linear offset to logical coordinate

@ -227,7 +227,7 @@ public:
  /// Compute the number of contiguous elements needed to store a tensor with the given size
  CUTLASS_HOST_DEVICE
  LongIndex capacity(MatrixCoord const &extent) const {
    return extent.column() * stride_[0];
    return LongIndex(extent.column()) * LongIndex(stride_[0]);
  }
};

@ -290,7 +290,7 @@ public:
  LongIndex operator()(MatrixCoord const &coord) const {
    Index row_major = coord.row() / kInterleave;
    Index row_minor = coord.row() % kInterleave;
    return row_major * stride_[0] + coord.column() * kInterleave + row_minor;
    return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor;
  }

  /// Inverse of layout function, mapping linear offset to logical coordinate

@ -397,7 +397,7 @@ public:
  LongIndex operator()(MatrixCoord const &coord) const {
    Index column_major = coord.column() / kInterleave;
    Index column_minor = coord.column() % kInterleave;
    return column_major * stride_[0] + coord.row() * kInterleave + column_minor;
    return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor;
  }

  /// Inverse of layout function, mapping linear offset to logical coordinate
@ -116,6 +116,11 @@ public:
    return PitchLinearCoord(Base::operator-(b));
  }

  CUTLASS_HOST_DEVICE
  PitchLinearCoord operator-() const {
    return PitchLinearCoord(-at(0), -at(1));
  }

  /// Element-wise multiplication
  CUTLASS_HOST_DEVICE
  PitchLinearCoord operator*(Base const& b) const {

@ -211,7 +216,7 @@ public:
  /// Assumes coordinate has convention (contiguous, strided)
  CUTLASS_HOST_DEVICE
  LongIndex operator()(TensorCoord const &coord) const {
    return coord.contiguous() + coord.strided() * stride_[0];
    return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]);
  }

  /// Returns the logical coordinate given an offset.
@ -33,7 +33,11 @@
    defined in cutlass/tensor_ref.h.
*/
#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include "assert.h"
#endif

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/matrix.h"
@ -45,7 +45,7 @@ enum class FloatRoundStyle {
  round_toward_zero,          ///< round toward zero
  round_to_nearest,           ///< round to nearest even
  round_toward_infinity,      ///< round toward infinity
  round_toward_neg_infinity,  ///< round toward negative infinity
  round_toward_neg_infinity   ///< round toward negative infinity
};

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -268,8 +268,7 @@ struct NumericConverterClamp {
    result_type const kClamp_min = -kClamp_max - 1;
    bool is_int_min = !(s > kClamp_min);
    bool is_int_max = !(s < kClamp_max);

    return (is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s)));
    return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s));
  }

  CUTLASS_HOST_DEVICE

@ -295,15 +294,15 @@ struct NumericConverterClamp<T, float> {
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const & s) {

    NumericConverter<result_type, source_type> convert_op;
    NumericConverter<result_type, double> convert_op;

    float kClamp_max = float((1 << (sizeof_bits<result_type>::value - 1)) - 1);
    float kClamp_min = -kClamp_max - 1;
    double kClamp_max = double((1U << (sizeof_bits<result_type>::value - 1)) - 1);
    double kClamp_min = -kClamp_max - 1;

    float source = s;
    double source = s;

    source = fmaxf(source, kClamp_min);
    source = fminf(source, kClamp_max);
    source = fmax(source, kClamp_min);
    source = fmin(source, kClamp_max);

    return convert_op(source);
  }
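Clamping in double rather than float fixes a precision problem: float cannot represent INT32_MAX exactly (it rounds up to 2^31), so the old float bounds could admit values that overflow the target type. A quick illustration:

```
#include <cassert>
#include <cstdint>

int main() {
  int32_t int_max = 2147483647;

  float  as_float  = float(int_max);   // rounds up to 2147483648.0f == 2^31
  double as_double = double(int_max);  // exact

  assert(as_float  == 2147483648.0f);  // off by one: unsafe as a clamp bound
  assert(as_double == 2147483647.0);   // exact: safe as a clamp bound
  return 0;
}
```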
@ -353,6 +352,24 @@ struct NumericArrayConverter {
  }
};

template <
  typename T,
  int N,
  FloatRoundStyle Round
>
struct NumericArrayConverter<T, T, N, Round> {

  using result_type = Array<T, N>;
  using source_type = Array<T, N>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const &s) {
    return s;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Array<half, 2> <= Array<float, 2>, round to nearest

@ -498,10 +515,15 @@ struct NumericArrayConverter<float, half_t, N, Round> {
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
    ((__CUDACC_VER_MAJOR__ > 10) || \
     ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))

/// Partial specialization for Array<int8_t, 1> <= Array<int, 1>
template <

@ -625,12 +647,13 @@ struct NumericArrayConverter<int8_t, int, N, Round> {
    return convert(s);
  }
};

#endif // Conditional guards to enable partial specialization for packed integers
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \
    ((__CUDACC_VER_MAJOR__ > 10) || \
     ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))

/// Partial specialization for Array<int4b_t, 8> <= Array<int, 8>
template <

@ -707,4 +730,121 @@ struct NumericArrayConverter<int4b_t, int, N, Round> {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// FastNumericArrayConverter only works when the source values lie within the center of the
/// representable range.
/// Conversion operator for Array
template <typename T, typename S, int N,
          FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
struct FastNumericArrayConverter {
  using result_type = Array<T, N>;
  using source_type = Array<S, N>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &s) {
    NumericArrayConverter<T, S, N, Round> convert_;

    return convert_(s);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) { return convert(s); }
};

/// Partial specialization for Array<float> <= Array<int>
template <typename T, int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<float, T, N, Round> {
  using result_type = Array<float, N>;
  using source_type = Array<T, N>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      int tmp = source[i] + 1262485504 /*0x4B400000*/;
      result[i] = reinterpret_cast<float const &>(tmp) - 12582912.0f;
    }

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) { return convert(s); }
};
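The magic constants implement the classic exponent-trick int→float conversion: 0x4B400000 is the bit pattern of 12582912.0f (1.5 · 2^23), so adding a small integer lands it directly in the float's mantissa bits, and subtracting 12582912.0f recovers the value without an int-to-float instruction. It is only exact while the integer fits in the low mantissa bits — hence the "center range" caveat above:

```
#include <cassert>
#include <cstring>

// Exponent-trick conversion, as in the specialization above. Exact only
// for integers in roughly [-2^22, 2^22) with this particular constant.
float fast_int_to_float(int x) {
  int tmp = x + 1262485504;  // 0x4B400000 == bits of 12582912.0f (1.5 * 2^23)
  float f;
  std::memcpy(&f, &tmp, sizeof(f));  // reinterpret the bit pattern
  return f - 12582912.0f;
}

int main() {
  assert(fast_int_to_float(0) == 0.0f);
  assert(fast_int_to_float(123) == 123.0f);
  assert(fast_int_to_float(-7) == -7.0f);
  return 0;
}
```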
/// Partial specialization for Array<int8_t, 4> <= Array<float, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<int8_t, float, 4, Round> {
  using result_type = Array<int8_t, 4>;
  using source_type = Array<float, 4>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    Array<int32_t, 4> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      float tmp = source[i] + 12582912.0f;
      result[i] = reinterpret_cast<int32_t const &>(tmp);
    }

    result[0] = __byte_perm(result[0], result[1], 0x40);
    result[2] = __byte_perm(result[2], result[3], 0x40);
    result[0] = __byte_perm(result[0], result[2], 0x5410);

    return reinterpret_cast<result_type const &>(result[0]);
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) { return convert(s); }
};
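Here the same exponent trick runs in reverse: after adding 12582912.0f, the rounded integer sits in the low byte of each float's bit pattern, and the three __byte_perm instructions gather those four low bytes into one 32-bit word. A host-side sketch of the byte gathering, with __byte_perm re-implemented since it is a CUDA device intrinsic:

```
#include <cassert>
#include <cstdint>

// Host re-implementation of CUDA's __byte_perm(a, b, selector):
// each selector nibble picks one byte from the 8-byte pool {a, b}.
uint32_t byte_perm(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = (uint64_t(b) << 32) | a;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t idx = (sel >> (4 * i)) & 0xF;
    out |= uint32_t((pool >> (8 * idx)) & 0xFF) << (8 * i);
  }
  return out;
}

int main() {
  // Four words whose low bytes are the converted int8 values 1..4.
  uint32_t r0 = 0x4B400001, r1 = 0x4B400002, r2 = 0x4B400003, r3 = 0x4B400004;

  uint32_t lo = byte_perm(r0, r1, 0x40);        // low bytes of r0, r1
  uint32_t hi = byte_perm(r2, r3, 0x40);        // low bytes of r2, r3
  uint32_t packed = byte_perm(lo, hi, 0x5410);  // gather all four low bytes

  assert(packed == 0x04030201);
  return 0;
}
```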

/// Partial specialization for Array<int8_t> <= Array<float>
template <int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<int8_t, float, N, Round> {
  static_assert(!(N % 4), "N must be multiple of 4.");

  using result_type = Array<int8_t, N>;
  using source_type = Array<float, N>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    FastNumericArrayConverter<int8_t, float, 4, Round> convert_vector_;

    result_type result;

    Array<int8_t, 4> *result_ptr =
        reinterpret_cast<Array<int8_t, 4> *>(&result);
    Array<float, 4> const *source_ptr =
        reinterpret_cast<Array<float, 4> const *>(&source);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 4; ++i) {
      result_ptr[i] = convert_vector_(source_ptr[i]);
    }

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) { return convert(s); }
};
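
Usage mirrors NumericArrayConverter. A sketch, assuming these converters are reachable through cutlass/numeric_conversion.h in this release:

```
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"

// Convert eight floats to eight packed int8_t values; N = 8 dispatches as two
// 4-wide vectorized conversions through the specialization above.
__device__ void quantize(cutlass::Array<float, 8> const &src,
                         cutlass::Array<int8_t, 8> &dst) {
  cutlass::FastNumericArrayConverter<int8_t, float, 8> converter;
  dst = converter(src);
}
```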

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines preferred rounding mode for a pair of types
template <typename T, typename S>
struct PreferredRoundingMode {
  static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -28,7 +28,11 @@
*/
#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"

@ -95,7 +95,11 @@
// Dependencies
//-----------------------------------------------------------------------------

#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <stdint.h>
#endif

#if !defined(__CUDACC_RTC__)
//-----------------------------------------------------------------------------

@ -28,10 +28,13 @@
*/
#pragma once

#if !defined(__CUDACC_RTC__)
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/cstdint>
#else
#include <assert.h>
#endif
#include <stdint.h>
#endif

#include "cutlass/cutlass.h"

@ -30,6 +30,10 @@ namespace cutlass {
template <typename T>
struct RealType {
  using Type = T;

  static T from_real(double x) {
    return static_cast<T>(x);
  }
};

template <typename T>
@ -38,4 +42,5 @@ static T from_real(double r) {
  return T(r);
}

} // namespace cutlass

@ -85,7 +85,13 @@ struct ReduceAdd {

    plus<FragmentAccumulator> op;

    return op(accumulator, element);
    NumericArrayConverter<
      ElementAccumulator,
      Element,
      kCount,
      PreferredRoundingMode<ElementAccumulator, Element>::kRound> converter;

    return op(accumulator, converter(element));
  }
};
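
The net effect of the ReduceAdd change: the incoming fragment is converted explicitly to the accumulator type, using the preferred rounding mode defined above, before the element-wise add, rather than leaving `plus<>` to reconcile mixed types. A sketch of the resulting call pattern, with float accumulation over half_t fragments assumed purely for illustration:

```
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

// One mixed-precision reduction step: convert, then add (types are assumptions).
CUTLASS_HOST_DEVICE
cutlass::Array<float, 8> reduce_step(
    cutlass::Array<float, 8> const &accumulator,
    cutlass::Array<cutlass::half_t, 8> const &element) {

  cutlass::plus<cutlass::Array<float, 8>> op;
  cutlass::NumericArrayConverter<
      float, cutlass::half_t, 8,
      cutlass::PreferredRoundingMode<float, cutlass::half_t>::kRound> converter;

  return op(accumulator, converter(element));
}
```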


@ -66,8 +66,9 @@ public:
  /// Permit fetching the synchronization mechanism early
  CUTLASS_DEVICE
  void fetch() {

    asm volatile ("ld.global.cg.s32 %0, [%1];\n" : "=r"(state) : "l"(lock));
    if (wait_thread) {
      asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
    }
  }

  /// Gets the internal state
@ -80,14 +81,8 @@ public:
  CUTLASS_DEVICE
  void wait(int status = 0) {

    if (wait_thread) {
      while (state != status) {

        fetch();

        __syncwarp(0x01);

      };
    while( __syncthreads_and(state != status) ) {
      fetch();
    }

    __syncthreads();
@ -99,8 +94,7 @@ public:
    __syncthreads();

    if (wait_thread) {

      asm volatile ("st.global.cg.s32 [%0], %1;\n" : : "l"(lock), "r"(status));
      asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
    }
  }
};
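
The rewritten `wait()` trades the single-thread spin (guarded by `wait_thread` and synchronized with `__syncwarp`) for a block-wide spin: `__syncthreads_and` returns non-zero only while the predicate is non-zero for every thread in the block, so all threads poll and leave the loop in lock-step. A minimal sketch of the pattern in isolation, assuming `lock` is a device-visible flag written by another threadblock (the real code re-reads it with `ld.global.cg`):

```
__device__ void spin_wait(volatile int const *lock, int expected) {
  int state = *lock;                             // every thread reads the flag
  // Non-zero only while (state != expected) holds across the whole block,
  // so the block spins together and exits together.
  while (__syncthreads_and(state != expected)) {
    state = *lock;                               // poll again
  }
}
```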

@ -379,6 +379,7 @@ template <
  typename Element,
  typename Layout
>
CUTLASS_HOST_DEVICE
bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment) {

  int const kStrideRank = Layout::kStrideRank;

368
include/cutlass/tensor_ref_planar_complex.h
Normal file
@ -0,0 +1,368 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines a structure containing strides, bounds, and a pointer to tensor data.
*/
#pragma once

#include <cstdint>
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/tensor_ref.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

///////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element_>
struct PlanarComplexReference {

  //
  // Type definitions
  //

  using Element = Element_;
  using ComplexElement = complex<Element>;

  //
  // Data members
  //

  Element *real;
  Element *imag;

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  PlanarComplexReference(
    Element *real_ = nullptr,
    Element *imag_ = nullptr
  ):
    real(real_), imag(imag_) { }

  /// Loads the complex element
  CUTLASS_HOST_DEVICE
  operator complex<Element>() const {
    return complex<Element>{*real, *imag};
  }

  /// Stores a complex element to the location pointed to by the reference
  CUTLASS_HOST_DEVICE
  PlanarComplexReference &operator=(complex<Element> const &rhs) {
    *real = rhs.real();
    *imag = rhs.imag();
    return *this;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
   and layout within memory. A TensorRef combines a pointer and a Layout concept.
*/
template <
  /// Data type of element stored within tensor (concept: NumericType)
  typename Element_,
  /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
  typename Layout_
>
class TensorRefPlanarComplex {
 public:
  /// Data type of individual access
  using Element = Element_;

  /// Complex element type
  using ComplexElement = complex<Element>;

  /// Mapping function from logical coordinate to linear memory
  using Layout = Layout_;

  static_assert(sizeof_bits<Element>::value >= 8,
    "Planar complex not suitable for subbyte elements at this time");

  /// Reference type to an element
  using Reference = PlanarComplexReference<Element>;

  /// Logical rank of tensor index space
  static int const kRank = Layout::kRank;

  /// Index type
  using Index = typename Layout::Index;

  /// Long index used for pointer offsets
  using LongIndex = typename Layout::LongIndex;

  /// Coordinate in logical tensor space
  using TensorCoord = typename Layout::TensorCoord;

  /// Layout's stride vector
  using Stride = typename Layout::Stride;

  /// TensorRef to constant data
  using ConstTensorRef = TensorRefPlanarComplex<
    typename platform::remove_const<Element>::type const,
    Layout>;

  /// TensorRef to non-constant data
  using NonConstTensorRef = TensorRefPlanarComplex<
    typename platform::remove_const<Element>::type,
    Layout>;

  /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
  /// scalar, but degenerate cases such as these are difficult to accommodate without
  /// extensive C++ metaprogramming or support for zero-length arrays.
  static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");

 private:

  /// Pointer
  Element* ptr_;

  /// Layout object maps logical coordinates to linear offsets
  Layout layout_;

  /// Offset to imaginary part
  LongIndex imaginary_stride_;

 public:

  //
  // Methods
  //

  /// Constructs a TensorRef with a pointer and layout object.
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex(
    Element *ptr = nullptr,             ///< pointer to start of tensor
    Layout const &layout = Layout(),    ///< layout object containing stride and mapping function
    LongIndex imaginary_stride = 0
  ):
    ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) {

  }

  /// Converting constructor from TensorRef to non-constant data.
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex(
    NonConstTensorRef const &ref        ///< TensorRef to non-const data
  ):
    ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { }

  /// Returns a reference to constant-valued tensor.
  CUTLASS_HOST_DEVICE
  ConstTensorRef const_ref() const {
    return ConstTensorRef(ptr_, layout_, imaginary_stride_);
  }

  CUTLASS_HOST_DEVICE
  NonConstTensorRef non_const_ref() const {
    return NonConstTensorRef(
      const_cast<typename platform::remove_const<Element>::type *>(ptr_),
      layout_,
      imaginary_stride_);
  }

  /// Updates only the pointer
  CUTLASS_HOST_DEVICE
  void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) {
    ptr_ = ptr;
    imaginary_stride_ = imaginary_stride;
  }

  /// Updates the pointer and layout object
  CUTLASS_HOST_DEVICE
  void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) {
    ptr_ = ptr;
    layout_ = layout;
    imaginary_stride_ = imaginary_stride;
  }

  /// Returns true if the TensorRef is non-null
  CUTLASS_HOST_DEVICE
  bool good() const {
    return ptr_ != nullptr;
  }

  /// Returns the pointer to referenced data
  CUTLASS_HOST_DEVICE
  Element * data() const { return ptr_; }

  /// Returns the pointer to the imaginary part of the referenced data
  CUTLASS_HOST_DEVICE
  Element * imaginary_data() const { return ptr_ + imaginary_stride_; }

  /// Returns a reference to the element at a given linear index
  CUTLASS_HOST_DEVICE
  Reference data(LongIndex idx) const {
    return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_);
  }

  /// Returns the layout object
  CUTLASS_HOST_DEVICE
  Layout & layout() {
    return layout_;
  }

  /// Returns the layout object
  CUTLASS_HOST_DEVICE
  Layout layout() const {
    return layout_;
  }

  /// Gets the stride to an imaginary element
  LongIndex imaginary_stride() const {
    return imaginary_stride_;
  }

  /// Gets the stride to an imaginary element
  LongIndex &imaginary_stride() {
    return imaginary_stride_;
  }

  /// Returns the layout object's stride vector
  CUTLASS_HOST_DEVICE
  Stride stride() const {
    return layout_.stride();
  }

  /// Returns the layout object's stride vector
  CUTLASS_HOST_DEVICE
  Stride & stride() {
    return layout_.stride();
  }

  /// Returns the layout object's stride in a given physical dimension
  CUTLASS_HOST_DEVICE
  Index stride(int dim) const {
    return layout_.stride().at(dim);
  }

  /// Returns the layout object's stride in a given physical dimension
  CUTLASS_HOST_DEVICE
  Index & stride(int dim) {
    return layout_.stride().at(dim);
  }

  /// Computes the offset of an index from the origin of the tensor
  CUTLASS_HOST_DEVICE
  LongIndex offset(TensorCoord const& coord) const {
    return layout_(coord);
  }

  /// Returns a reference to the element at a given Coord
  CUTLASS_HOST_DEVICE
  Reference at(TensorCoord const& coord) const {
    return data(offset(coord));
  }

  /// Returns a reference to the element at a given Coord
  CUTLASS_HOST_DEVICE
  Reference operator[](TensorCoord const& coord) const {
    return data(offset(coord));
  }

  /// Adds an offset to each pointer
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) {
    ptr_ += offset_;
    return *this;
  }

  /// Adds an offset to each pointer
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) {
    add_pointer_offset(offset(coord));
    return *this;
  }

  /// Returns a TensorRef offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex operator+(TensorCoord const& b) const {
    TensorRefPlanarComplex result(*this);
    result.add_coord_offset(b);
    return result;
  }

  /// Returns a TensorRef offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex & operator+=(TensorCoord const& b) {
    add_coord_offset(b);
    return *this;
  }

  /// Returns a TensorRef offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex operator-(TensorCoord const& b) const {
    TensorRefPlanarComplex result(*this);
    result.add_pointer_offset(-offset(b));
    return result;
  }

  /// Returns a TensorRef offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorRefPlanarComplex & operator-=(TensorCoord const& b) {
    add_pointer_offset(-offset(b));
    return *this;
  }

  /// TensorRef to real-valued tensor
  CUTLASS_HOST_DEVICE
  cutlass::TensorRef<Element, Layout> ref_real() const {
    return cutlass::TensorRef<Element, Layout>(data(), layout());
  }

  /// TensorRef to imaginary-valued tensor
  CUTLASS_HOST_DEVICE
  cutlass::TensorRef<Element, Layout> ref_imag() const {
    return cutlass::TensorRef<Element, Layout>(imaginary_data(), layout());
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Constructs a TensorRefPlanarComplex, deducing types from arguments.
template <
  typename Element,
  typename Layout
>
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex<Element, Layout> make_TensorRefPlanarComplex(
  Element *ptr,
  Layout const &layout,
  int64_t imaginary_stride) {

  return TensorRefPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
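
A usage sketch for the planar-complex reference (hypothetical sizes): the real parts of an M-by-N row-major matrix occupy the first M*N elements of a buffer, and the imaginary parts follow at offset M*N:

```
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref_planar_complex.h"

void example(float *buffer, int M, int N) {
  cutlass::TensorRefPlanarComplex<float, cutlass::layout::RowMajor> ref(
      buffer,                         // pointer to the real part
      cutlass::layout::RowMajor(N),   // leading dimension
      int64_t(M) * N);                // imaginary_stride: imag = real + M*N

  // Reading loads {real, imag} as one complex value; writing stores both parts.
  cutlass::complex<float> c = ref.at({0, 1});
  ref.at({0, 1}) = cutlass::complex<float>(c.real(), -c.imag());
}
```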
293
include/cutlass/tensor_view_planar_complex.h
Normal file
@ -0,0 +1,293 @@
/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines a structure containing strides and a pointer to tensor data.

    TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
    it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
    data storage and is therefore lightweight and may be embedded in larger tensor objects or
    memory structures.

    See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
    linear memory.
*/

#pragma once

#if !defined(__CUDACC_RTC__)
#include <cmath>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"
#include "cutlass/tensor_ref_planar_complex.h"

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Data type of element stored within tensor
  typename Element_,
  /// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
  typename Layout_
>
class TensorViewPlanarComplex : public TensorRefPlanarComplex<Element_, Layout_> {
 public:

  /// Base tensor reference
  using Base = cutlass::TensorRefPlanarComplex<Element_, Layout_>;

  /// Mapping function from logical coordinate to internal n-D array
  using Layout = Layout_;

  /// TensorRef pointing to constant memory
  using ConstTensorRef = typename Base::ConstTensorRef;

  /// Underlying TensorRef type
  using TensorRef = Base;

  /// Data type of individual access
  using Element = Element_;

  /// Reference type to an element
  using Reference = Element &;

  /// Logical rank of tensor index space
  static int const kRank = Layout::kRank;

  /// Index type
  using Index = typename Layout::Index;

  /// Long index used for pointer offsets
  using LongIndex = typename Layout::LongIndex;

  /// Coordinate in logical tensor space
  using TensorCoord = typename Layout::TensorCoord;

  /// Coordinate in storage n-D array
  using Stride = typename Layout::Stride;

  /// TensorView pointing to constant memory
  using ConstTensorView = TensorViewPlanarComplex<
    typename platform::remove_const<Element>::type const,
    Layout>;

  /// TensorView pointing to non-constant memory
  using NonConstTensorView = TensorViewPlanarComplex<
    typename platform::remove_const<Element>::type,
    Layout>;

  /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
  /// scalar, but degenerate cases such as these are difficult to accommodate without
  /// extensive C++ metaprogramming or support for zero-length arrays.
  static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");

 private:

  /// View extent
  TensorCoord extent_;

 public:

  //
  // Methods
  //

  /// Constructs a TensorView object
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) {

  }

  /// Constructs a TensorView object
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex(
    Element *ptr,                 ///< pointer to start of tensor
    Layout const &layout,         ///< layout object containing stride and mapping function
    LongIndex imaginary_stride,   ///< stride between real and imaginary part
    TensorCoord const &extent     ///< size of the view in logical coordinates
  ):
    Base(ptr, layout, imaginary_stride), extent_(extent) {

  }

  /// Constructs a TensorView object
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex(
    TensorRef const &ref,         ///< pointer and layout object referencing a tensor
    TensorCoord const &extent     ///< logical size of tensor
  ):
    Base(ref), extent_(extent) {

  }

  /// Converting constructor from TensorRef to non-constant data.
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex(
    NonConstTensorView const &view  ///< TensorView to non-const data
  ):
    Base(view), extent_(view.extent_) { }

  /// Updates the pointer and layout object
  CUTLASS_HOST_DEVICE
  void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) {
    Base::reset(ptr, layout, imaginary_stride);
    this->resize(size);
  }

  /// Changes the size of the view without affecting pointer or layout
  CUTLASS_HOST_DEVICE
  void resize(TensorCoord extent) {
    this->extent_ = extent;
  }

  /// Returns the extent of the view (the size along each logical dimension).
  CUTLASS_HOST_DEVICE
  TensorCoord const& extent() const { return extent_; }

  /// Returns the extent along a particular logical dimension.
  CUTLASS_HOST_DEVICE
  Index extent(int dim) const { return extent_.at(dim); }

  /// Determines whether a location is within a tensor
  CUTLASS_HOST_DEVICE
  bool contains(TensorCoord const& coord) const {
    CUTLASS_PRAGMA_UNROLL
    for (int dim = 0; dim < kRank; ++dim) {
      if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
        return false;
      }
    }
    return true;
  }

  /// Returns a TensorRef pointing to the first element of the tensor.
  CUTLASS_HOST_DEVICE
  Base ref() const {
    return Base(this->data(), this->layout(), this->imaginary_stride());
  }

  /// Returns a TensorRef pointing to the first element of the tensor.
  CUTLASS_HOST_DEVICE
  ConstTensorRef const_ref() const {
    return ConstTensorRef(this->data(), this->layout(), this->imaginary_stride());
  }

  /// Returns a TensorView to const data
  CUTLASS_HOST_DEVICE
  ConstTensorView const_view() const {
    return ConstTensorView(const_ref(), extent_);
  }

  /// Returns a TensorView given location and size quantities
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex subview(
    TensorCoord extent,                          ///< extent of the resulting view
    TensorCoord const& location = TensorCoord()  ///< resulting view's origin within the old view
  ) const {

    return TensorViewPlanarComplex(ref(), extent.clamp(extent_ - location)).add_coord_offset(location);
  }

  /// Returns the number of scalar elements needed to store tensor.
  CUTLASS_HOST_DEVICE
  size_t capacity() const {
    return Base::layout().capacity(extent_);
  }

  /// Returns a TensorView offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex operator+(
    TensorCoord const& b        ///< offset in the logical coordinate space of the tensor
  ) const {

    TensorViewPlanarComplex result(*this);
    result.add_pointer_offset(this->offset(b));
    return result;
  }

  /// Returns a TensorView offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex& operator+=(
    TensorCoord const& b        ///< offset in the logical coordinate space of the tensor
  ) {

    this->add_pointer_offset(this->offset(b));
    return *this;
  }

  /// Returns a TensorView offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex operator-(
    TensorCoord const& b        ///< offset in the logical coordinate space of the tensor
  ) const {

    TensorViewPlanarComplex result(*this);
    result.add_pointer_offset(-this->offset(b));
    return result;
  }

  /// Returns a TensorView offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorViewPlanarComplex& operator-=(
    TensorCoord const& b        ///< offset in the logical coordinate space of the tensor
  ) {

    this->add_pointer_offset(-this->offset(b));
    return *this;
  }

  /// TensorView to real-valued tensor
  CUTLASS_HOST_DEVICE
  cutlass::TensorView<Element, Layout> view_real() const {
    return cutlass::TensorView<Element, Layout>(this->data(), this->layout(), extent_);
  }

  /// TensorView to imaginary-valued tensor
  CUTLASS_HOST_DEVICE
  cutlass::TensorView<Element, Layout> view_imag() const {
    return cutlass::TensorView<Element, Layout>(this->imaginary_data(), this->layout(), extent_);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Constructs a TensorViewPlanarComplex, deducing types from arguments.
template <
  typename Element,
  typename Layout
>
CUTLASS_HOST_DEVICE TensorViewPlanarComplex<Element, Layout> make_TensorViewPlanarComplex(
  Element *ptr,
  Layout const &layout,
  typename Layout::LongIndex imaginary_stride,
  typename Layout::TensorCoord const &extent) {

  return TensorViewPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride, extent);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass
@ -128,13 +128,13 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
    int stride_;
    /// amount (in byte) to increment pointer to move to next access along
    /// strided dimension
    int inc_strided_;
    LongIndex inc_strided_;
    /// amount (in byte) to increment pointer from last access to first access
    /// of next tile
    int inc_next_;
    LongIndex inc_next_;
    /// amount (in byte) to increment pointer from first access of current tile
    /// to first access of next tile
    int inc_advance_;
    LongIndex inc_advance_;

   public:

@ -145,20 +145,20 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout) : stride_(layout.stride(0)) {
      inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
      inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
                     sizeof_bits<Element>::value / 8;

      if (kAdvanceRank) {
        // advance along strided dimension
        inc_advance_ =
            Shape::kStrided * stride_ * sizeof_bits<Element>::value / 8;
            Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
      } else {
        // advance along contiguous dimension
        inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
      }

      inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
                                     ThreadMap::Delta::kStrided * stride_ *
      inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
                                     ThreadMap::Delta::kStrided * LongIndex(stride_) *
                                     sizeof_bits<Element>::value / 8;
    };
  };
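
The point of widening the increments to `LongIndex`: each increment multiplies a stride by a tile extent and the element size in bytes, which can exceed 2^31 for large problems. A worked example with hypothetical sizes:

```
#include <cstdint>

int main() {
  int32_t stride = 1 << 24;  // leading dimension of 16M elements
  // A 64-row tile advance over a float operand:
  //   64 * 2^24 * 4 bytes = 2^32, which no longer fits in a 32-bit increment.
  int64_t inc_advance = int64_t(64) * stride * int64_t(sizeof(float));
  return inc_advance == (int64_t(1) << 32) ? 0 : 1;
}
```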

@ -280,7 +280,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
    TensorCoord residue_extent;
    if (kAdvanceRank) {

      Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
      Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
      if (!residue_size) {
        residue_size = Shape::kStrided;
      }
@ -288,18 +288,19 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
      residue_offset_ = make_Coord(0, residue_size);
      residue_extent = make_Coord(
        extent_.contiguous(),
        min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
        min(threadblock_offset.strided() + residue_size, extent_.strided())
      );

    } else {

      Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
      Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
      if (!residue_size) {
        residue_size = Shape::kContiguous;
      }

      residue_offset_ = make_Coord(residue_size, 0);

      residue_extent = make_Coord(
        min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
        min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
        extent_.strided()
      );
    }
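
The corrected residue computation subtracts the threadblock's starting offset before taking the remainder, which changes the result whenever that offset is not a multiple of the tile shape. A worked example with hypothetical sizes:

```
#include <cassert>

int main() {
  int extent = 1000, tile = 64, tb_offset = 100;  // offset not tile-aligned
  int residue_old = extent % tile;                // 40: ignores where the block starts
  int residue_new = (extent - tb_offset) % tile;  // (1000 - 100) % 64 == 4
  assert(residue_old == 40 && residue_new == 4);
  return 0;
}
```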

@ -362,18 +363,18 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
      compute_predicates_(extent_, true);

      if (kAdvanceRank) {
        pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
        pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
        pointer_ += Shape::kContiguous * tile_offset.contiguous();
      } else {
        pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
        pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
        pointer_ += Shape::kStrided * tile_offset.strided();
      }
    } else {
      if (kAdvanceRank) {
        pointer_ += params_.inc_advance_ * tile_offset.strided();
        pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
        pointer_ += Shape::kContiguous * tile_offset.contiguous();
      } else {
        pointer_ += params_.inc_advance_ * tile_offset.contiguous();
        pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
        pointer_ += Shape::kStrided * tile_offset.strided();
      }
    }

@ -296,7 +296,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,

  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {

    load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
@ -310,10 +315,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          address_iterator_.set_iteration_index(idx);
          auto ptr = (address_iterator_.get() + pointer_offset);
          char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;

          AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);

          if (address_iterator_.valid()) {
            frag_ptr[idx] = *ptr;
            frag_ptr[idx] = *access_ptr;
          }
          ++address_iterator_;
        }
@ -323,11 +330,17 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
  void load(Fragment &frag) { load_with_byte_offset(frag, 0); }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    address_iterator_.set_iteration_index(0);
    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

@ -340,8 +353,11 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,

        int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

        char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
        AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);

        if (address_iterator_.valid()) {
          *(address_iterator_.get() + pointer_offset) = frag_ptr[idx];
          *access_ptr = frag_ptr[idx];
        }
        ++address_iterator_;
      }
@ -351,7 +367,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
  void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
};
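
With this change every load and store funnels through the byte-offset path; the pointer-offset entry points just scale the element offset by the element width in bits. A sketch of that conversion, assuming `sizeof_bits` from cutlass/numeric_types.h; it stays exact for sub-byte element types whenever the element offset spans a whole number of bytes:

```
#include <cstdint>
#include "cutlass/numeric_types.h"

template <typename Element>
int64_t to_byte_offset(int64_t pointer_offset) {
  return pointer_offset * cutlass::sizeof_bits<Element>::value / 8;
}

// e.g. to_byte_offset<float>(8) == 32 and to_byte_offset<cutlass::int4b_t>(8) == 4
```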

////////////////////////////////////////////////////////////////////////////////
@ -528,6 +544,12 @@ public:
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
@ -540,6 +562,12 @@ public:
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
@ -721,6 +749,12 @@ public:
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
@ -732,6 +766,12 @@ public:
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
Some files were not shown because too many files have changed in this diff.