CUTLASS 2.1 (#83)

CUTLASS 2.1 contributes:
- BLAS-style host-side API added to CUTLASS Library
- Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
- Minor enhancements and bug fixes
This commit is contained in:
Andrew Kerr
2020-04-07 13:51:25 -07:00
committed by GitHub
parent 7c0cd26d13
commit 96dab34ad9
196 changed files with 20653 additions and 1995 deletions


@ -1,6 +1,14 @@
# NVIDIA CUTLASS Changelog
# CUTLASS 2.0
# CUTLASS 2.x
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
* [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu)
* Minor enhancements and bug fixes
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
* Substantially refactored for
@ -22,7 +30,7 @@
* Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
* Batched GEMM operations
* Complex-valued GEMMs
* Note: a host compiler supporting C++11 or greater is required.
* **Note: a host compiler supporting C++11 or greater is required.**
# CUTLASS 1.x
@ -76,7 +84,7 @@
## Copyright
Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted

CMakeLists.txt Normal file → Executable file

@ -32,15 +32,14 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
project(CUTLASS VERSION 2.0.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.1.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
find_package(Doxygen QUIET)
#
# CUTLASS 2.0 requires C++11
# CUTLASS 2.x requires C++11
#
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
@ -49,7 +48,7 @@ if(CUTLASS_NATIVE_CUDA)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
else()
string(APPEND NVCC_FLAGS " --std=c++11")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11)
endif()
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
@ -58,13 +57,26 @@ endif()
message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")
if(${CMAKE_PROJECT_NAME} MATCHES ${PROJECT_NAME})
set(_CUTLASS_ENABLE_TESTS ON)
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
if(CUTLASS_ENABLE_HEADERS_ONLY)
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
set(CUTLASS_ENABLE_TOOLS_INIT OFF)
else()
set(_CUTLASS_ENABLE_TESTS OFF)
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
endif()
set(CUTLASS_ENABLE_TESTS ${_CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS Tests")
set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
if (CUTLASS_ENABLE_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
@ -86,31 +98,22 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()
if(CUDA_COMPILER MATCHES "[Cc]lang")
if(NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
endif()
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
endif()
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
# Special policy introduced in CMake 3.13
if (POLICY CMP0076)
cmake_policy(SET CMP0076 NEW)
endif()
endif()
# check if the configuration is supported
if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
endif()
include(GNUInstallDirs)
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
###################################################################################################
#
# Configure CMake variables
@ -120,11 +123,14 @@ include(GNUInstallDirs)
message(STATUS "CUDA Compilation Architectures: ${CUTLASS_NVCC_ARCHS_ENABLED}")
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
# By default we want to build in Release mode to ensure that we're getting best performance.
# By default we want to build in Release mode to ensure that we're getting best performance.
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release")
endif()
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries")
if(WIN32)
# On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
@ -132,29 +138,35 @@ endif()
if (WIN32)
# Enable more warnings and treat as errors
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)
# Disable warning on Unicode characters
string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/fp:strict)
endif(WIN32)
if (${CUTLASS_NVCC_VERBOSE})
string(APPEND NVCC_FLAGS " -v")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
endif()
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")
set(CUTLASS_LIBRARY_KERNELS "128x128" CACHE STRING "Comma delimited list of kernel name filters. Default '' means all kernels are enabled.")
#
# CUTLASS generator cmake configuration
#
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
string(APPEND NVCC_FLAGS " -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
@ -166,7 +178,7 @@ else()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
endif()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.")
#
@ -182,7 +194,7 @@ set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CAC
# ...
#
if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer
string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --compiler-options=-fsanitize=address --compiler-options=-fno-omit-frame-pointer)
string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address")
endif()
@ -192,85 +204,127 @@ endif()
#
###################################################################################################
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
if(CUTLASS_NVCC_EMBED_CUBIN)
string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=sm_${ARCH}")
endif()
if(CUTLASS_NVCC_EMBED_PTX)
string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=compute_${ARCH}")
endif()
string(APPEND CLANG_FLAGS " --cuda-gpu-arch=sm_${ARCH}")
endforeach()
if(CUTLASS_NVCC_EMBED_PTX)
string(APPEND CLANG_FLAGS " --cuda-include-ptx=all")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all)
endif()
if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
endif()
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
string(APPEND NVCC_FLAGS " --keep") # --keep-dir may not work with nvcc for some directories.
string(APPEND CLANG_FLAGS " -save-temps=${CUTLASS_NVCC_KEEP_DIR}")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
endif()
if (CUTLASS_ENABLE_F16C)
string(APPEND COMPILER_FLAGS " -DCUTLASS_ENABLE_F16C=1")
if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1)
if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
string(APPEND NVCC_FLAGS " -Xcompiler -mf16c")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c)
elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC"))
string(APPEND NVCC_FLAGS " -Xcompiler /arch:AVX2")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2)
endif()
endif()
string(APPEND NVCC_FLAGS " -lineinfo")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
string(APPEND CLANG_FLAGS " -gmlt")
if (UNIX)
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
string(APPEND NVCC_FLAGS " -Xcompiler -fno-strict-aliasing")
# Don't leak lineinfo in release builds
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo)
endif()
if(CUDA_COMPILER MATCHES "[Cc]lang")
string(APPEND CLANG_FLAGS " --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}")
string(APPEND CLANG_FLAGS " -mllvm -pragma-unroll-threshold=100000")
string(APPEND CLANG_FLAGS " -mllvm -unroll-threshold=5000")
string(APPEND CLANG_FLAGS " -Wno-unused-command-line-argument")
if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
endif()
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
endif()
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR)
list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR)
string(APPEND CLANG_FLAGS " -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR})
# needed for libcublasLt.so in case it's installed in the same location as libcudart.so
# the dynamic linker can find it if the linker sets RPATH (forced by --disable-new-dtags)
# Otherwise linker uses RUNPATH and that does not propagate to loaded libs.
string(APPEND CLANG_FLAGS " -Wl,--disable-new-dtags")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags)
link_libraries(nvidia::cudart)
endif()
if(CUDA_COMPILER MATCHES "[Cc]lang")
string(APPEND CMAKE_CXX_FLAGS "${COMMON_FLAGS} ${CLANG_FLAGS}")
string(APPEND CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${CLANG_FLAGS_RELEASE}")
string(APPEND CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${CLANG_FLAGS_RELWITHDEBINFO}")
string(APPEND CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${CLANG_FLAGS_DEBUG}")
elseif (CUTLASS_NATIVE_CUDA)
string(APPEND CMAKE_CUDA_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}")
string(APPEND CMAKE_CUDA_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}")
string(APPEND CMAKE_CUDA_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}")
else()
string(APPEND CUDA_NVCC_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}")
string(APPEND CUDA_NVCC_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}")
string(APPEND CUDA_NVCC_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}")
string(APPEND CUDA_NVCC_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}")
endif()
function(cutlass_apply_cuda_gencode_flags TARGET)
set(NVCC_FLAGS)
set(CLANG_FLAGS)
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
set(CODES)
if(CUTLASS_NVCC_EMBED_CUBIN)
list(APPEND CODES sm_${ARCH})
endif()
if(CUTLASS_NVCC_EMBED_PTX)
list(APPEND CODES compute_${ARCH})
endif()
list(JOIN CODES "," CODES_STR)
list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
endforeach()
if (CUDA_COMPILER MATCHES "[Cc]lang")
target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
)
else()
target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>
)
endif()
endfunction()
function(cutlass_apply_standard_compile_options TARGET)
if(CUDA_COMPILER MATCHES "[Cc]lang")
set(CUDA_COMPILE_LANGUAGE CXX)
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS})
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG})
else()
set(CUDA_COMPILE_LANGUAGE CUDA)
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS})
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE})
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO})
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
endif()
target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:${_FLAGS}>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELEASE>:${_FLAGS_RELEASE}>>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELWITHDEBINFO>:${_FLAGS_RELWITHDEBINFO}>>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:DEBUG>:${_FLAGS_DEBUG}>>
)
endfunction()
#
# The following items should eventually be pushed into cutlass/CMakeLists.txt
@ -324,8 +378,8 @@ if (NOT DEFINED CUTLASS_REVISION)
endif()
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
@ONLY)
target_include_directories(
@ -338,8 +392,8 @@ target_include_directories(
)
install(
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
${CMAKE_CURRENT_BINARY_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
@ -399,27 +453,6 @@ endif()
################################################################################
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
if(CUTLASS_ENABLE_HEADERS_ONLY)
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
set(CUTLASS_ENABLE_TOOLS_INIT OFF)
else()
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
endif()
set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
if(CUTLASS_ENABLE_TOOLS)
add_subdirectory(tools)
endif()


@ -39,23 +39,27 @@ if(CUTLASS_NATIVE_CUDA)
enable_language(CUDA)
if(NOT CUDA_VERSION)
set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
endif()
if(NOT CUDA_TOOLKIT_ROOT_DIR)
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
endif()
else()
find_package(CUDA REQUIRED)
# We work around missing variables with the native flow by also finding the CUDA toolkit the old way.
endif()
if(NOT CMAKE_CUDA_COMPILER_VERSION)
set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION})
endif()
if(NOT CUDA_VERSION)
set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
endif()
if(NOT CUDA_TOOLKIT_ROOT_DIR)
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
endif()
if (CUDA_VERSION VERSION_LESS 9.2)
message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.")
endif()
if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang")
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
@ -74,7 +78,7 @@ find_library(
# in the CUDA toolkit we're building against.
)
if(CUDART_LIBRARY)
if(NOT TARGET cudart AND CUDART_LIBRARY)
message(STATUS "CUDART: ${CUDART_LIBRARY}")
@ -95,6 +99,10 @@ if(CUDART_LIBRARY)
${CUDART_LIBRARY}
)
elseif(TARGET cudart)
message(STATUS "CUDART: Already Found")
else()
message(STATUS "CUDART: Not Found")
@ -116,7 +124,7 @@ find_library(
# in the CUDA toolkit we're building against.
)
if(CUDA_DRIVER_LIBRARY)
if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)
message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}")
@ -137,6 +145,10 @@ if(CUDA_DRIVER_LIBRARY)
${CUDA_DRIVER_LIBRARY}
)
elseif(TARGET cuda_driver)
message(STATUS "CUDA Driver: Already Found")
else()
message(STATUS "CUDA Driver: Not Found")
@ -156,7 +168,7 @@ find_library(
# in the CUDA toolkit we're building against.
)
if(NVRTC_LIBRARY)
if(NOT TARGET nvrtc AND NVRTC_LIBRARY)
message(STATUS "NVRTC: ${NVRTC_LIBRARY}")
@ -177,6 +189,10 @@ if(NVRTC_LIBRARY)
${NVRTC_LIBRARY}
)
elseif(TARGET nvrtc)
message(STATUS "NVRTC: Already Found")
else()
message(STATUS "NVRTC: Not Found")
@ -197,48 +213,137 @@ function(cutlass_correct_source_file_language_property)
endif()
endfunction()
function(cutlass_add_library)
set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
set(options INTERFACE STATIC SHARED OBJECT)
set(oneValueArgs)
function(cutlass_unify_source_files TARGET_ARGS_VAR)
set(options)
set(oneValueArgs BATCH_SOURCES BATCH_SIZE)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang" OR __INTERFACE)
cutlass_correct_source_file_language_property(${ARGN})
add_library(${ARGN})
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_library(${ARGN})
if (NOT DEFINED TARGET_ARGS_VAR)
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()
if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
endif()
if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
set(CUDA_FILE_ARGS)
set(TARGET_SOURCE_ARGS)
foreach(ARG ${__UNPARSED_ARGUMENTS})
if(${ARG} MATCHES ".*\.cu$")
list(APPEND CUDA_FILE_ARGS ${ARG})
else()
list(APPEND TARGET_SOURCE_ARGS ${ARG})
endif()
endforeach()
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
while(NUM_CUDA_FILE_ARGS GREATER 0)
list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}")
string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH)
set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu)
message(STATUS "Generating ${BATCH_FILE}")
file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n")
foreach(CUDA_FILE ${CUDA_FILE_BATCH})
get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE)
file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n")
endforeach()
list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE})
if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE)
break()
endif()
list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS)
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
endwhile()
else()
set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
endif()
set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)
endfunction()
function(cutlass_add_executable)
function(cutlass_add_library NAME)
set(options)
set(oneValueArgs)
set(oneValueArgs EXPORT_NAME)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
cutlass_correct_source_file_language_property(${ARGN})
add_executable(${ARGN})
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
add_library(${NAME} ${TARGET_SOURCE_ARGS})
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_executable(${ARGN})
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
endif()
cutlass_apply_standard_compile_options(${NAME})
cutlass_apply_cuda_gencode_flags(${NAME})
target_compile_features(
${NAME}
INTERFACE
cxx_std_11
)
if(__EXPORT_NAME)
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
endif()
endfunction()
function(cutlass_target_sources)
function(cutlass_add_executable NAME)
set(options)
set(oneValueArgs)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_correct_source_file_language_property(${ARGN})
target_sources(${ARGN})
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS})
endif()
cutlass_apply_standard_compile_options(${NAME})
cutlass_apply_cuda_gencode_flags(${NAME})
target_compile_features(
${NAME}
INTERFACE
cxx_std_11
)
endfunction()
function(cutlass_target_sources NAME)
set(options)
set(oneValueArgs)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
target_sources(${NAME} ${TARGET_SOURCE_ARGS})
endfunction()


@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 2.0
# CUTLASS 2.1
_CUTLASS 2.0 - November 2019_
_CUTLASS 2.1 - April 2020_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@ -25,6 +25,14 @@ NVIDIA's Volta and Turing architectures.
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
# What's New in CUTLASS 2.1
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
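A minimal sketch of the new host-side API appears below. It mirrors the call made in [examples/10_planar_complex](/examples/10_planar_complex/planar_complex.cu); the problem size, column-major layouts, and zero imaginary scalars are illustrative choices, and the sketch assumes the CUTLASS Library was built with planar complex kernels enabled.
```
// Sketch only: mirrors examples/10_planar_complex; sizes and layouts are illustrative.
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/library/handle.h"
#include "cutlass/util/device_memory.h"

int main() {
  int m = 1024, n = 512, k = 256;

  // Real and imaginary parts share one allocation, separated by an "imaginary stride".
  int64_t imag_stride_A = int64_t(m) * k;
  int64_t imag_stride_B = int64_t(k) * n;
  int64_t imag_stride_C = int64_t(m) * n;

  cutlass::DeviceAllocation<cutlass::half_t> A, B, C, D;
  A.reset(imag_stride_A * 2);
  B.reset(imag_stride_B * 2);
  C.reset(imag_stride_C * 2);
  D.reset(imag_stride_C * 2);

  cutlass::complex<float> alpha(1.0f);
  cutlass::complex<float> beta(0.0f);

  cutlass::library::Handle handle;

  cutlass::Status status = handle.gemm_planar_complex(
    m, n, k,
    cutlass::library::NumericTypeID::kF32,            // accumulation type
    cutlass::library::NumericTypeID::kF32,            // alpha/beta scalar type
    &alpha,
    cutlass::library::NumericTypeID::kF16,            // A element type
    cutlass::library::LayoutTypeID::kColumnMajor,
    cutlass::library::ComplexTransform::kNone,
    A.get(), A.get() + imag_stride_A, m, m,           // real ptr, imag ptr, leading dims
    cutlass::library::NumericTypeID::kF16,            // B element type
    cutlass::library::LayoutTypeID::kColumnMajor,
    cutlass::library::ComplexTransform::kNone,
    B.get(), B.get() + imag_stride_B, k, k,
    &beta,
    cutlass::library::NumericTypeID::kF16,            // C and D element type
    C.get(), C.get() + imag_stride_C, m, m,
    D.get(), D.get() + imag_stride_C, m, m,
    1,                                                // batch count
    imag_stride_A * 2, imag_stride_A * 2,             // batch strides of A (real, imaginary)
    imag_stride_B * 2, imag_stride_B * 2,             // batch strides of B
    imag_stride_C * 2, imag_stride_C * 2,             // batch strides of C
    imag_stride_C * 2, imag_stride_C * 2);            // batch strides of D

  std::cout << (status == cutlass::Status::kSuccess ? "launched" : "configuration not supported")
            << std::endl;
  return 0;
}
```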
# What's New in CUTLASS 2.0
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
@ -33,7 +41,7 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to
- Robust and durable templates that reliably span the design space
- Encapsulated functionality that may be reusable in other contexts
See the [CHANGELOG](CHANGELOG.md) for more details.
**See the [CHANGELOG](CHANGELOG.md) for more details.**
See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
@ -82,7 +90,7 @@ any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU.
# Documentation
CUTLASS 2.0 is described in the following documents and the accompanying
CUTLASS 2.1 is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
@ -177,7 +185,7 @@ include/ # client applications should target this directory
### CUTLASS SDK Examples
CUTLASS SDK examples apply CUTLASS templates to implement basic computations.
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
```
examples/
@ -198,12 +206,19 @@ examples/
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores
08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores
10_planar_complex/ # example demonstrating planar complex GEMM kernels
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
```
### Tools
```
tools/
library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates
include/
cutlass/
library/
profiler/ # CUTLASS Profiler - command-line utility for executing operations in the
# CUTLASS Library
@ -279,7 +294,7 @@ The official list of CUTLASS developers and contributors is available here: [CON
# Copyright
Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted


@ -1,7 +1,8 @@
message(STATUS "Configuring cublas ...")
if(DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS)
if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
(DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED))
# Don't add cuBLAS if it's defined and false, assume it's not found.
@ -59,11 +60,13 @@ endif()
if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
if(WIN32)
add_library(cublas STATIC IMPORTED)
add_library(cublas STATIC IMPORTED GLOBAL)
else()
add_library(cublas SHARED IMPORTED)
add_library(cublas SHARED IMPORTED GLOBAL)
endif()
add_library(nvidia::cublas ALIAS cublas)
set_property(
TARGET cublas
PROPERTY IMPORTED_LOCATION
@ -88,23 +91,20 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
${CUBLAS_PATH}/lib/x64
/usr/lib/x86_64-linux-gnu)
if(_CUBLASLT_LIBRARY)
if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)
if(WIN32)
add_library(cublasLt STATIC IMPORTED)
add_library(cublasLt STATIC IMPORTED GLOBAL)
else()
add_library(cublasLt SHARED IMPORTED)
add_library(cublasLt SHARED IMPORTED GLOBAL)
endif()
set_property(
TARGET cublasLt
PROPERTY IMPORTED_LOCATION
${_CUBLASLT_LIBRARY})
target_link_libraries(
cublas
INTERFACE
cublasLt)
add_library(nvidia::cublasLt ALIAS cublasLt)
endif()


@ -173,14 +173,35 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
EpilogueOp>;
int main() {
//
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
//
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!(props.major >= 7)) {
std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70."
<< std::endl;
return 0;
return -1;
}
//
// Define problem size
//
const int length_m = 5120;
const int length_n = 4096;
const int length_k = 4096;


@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Volta GPU.
Writing a single high-performance matrix multiplication kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of a GEMM kernel. When used properly, the kernels can easily hit the peak
performance of the GPU.
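The "knobs" referred to above are the template parameters of cutlass::gemm::device::Gemm. A hedged sketch of one Volta Tensor Core configuration follows; the element types, layouts, and tile shapes are illustrative choices rather than the only supported ones, and the remaining parameters (epilogue, threadblock swizzle, stage count) are left at their defaults.
// Illustrative composition of a device-level CUTLASS GEMM for Volta Tensor Cores.
#include "cutlass/gemm/device/gemm.h"

using GemmVoltaTensorOp = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of operand A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of operand B
    float,           cutlass::layout::ColumnMajor,  // element and layout of C and D
    float,                                          // accumulator element
    cutlass::arch::OpClassTensorOp,                 // issue Tensor Core instructions
    cutlass::arch::Sm70,                            // target the Volta architecture
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile (M, N, K)
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile (M, N, K)
    cutlass::gemm::GemmShape<8, 8, 4>>;             // mma.sync instruction shape

int main() {
  GemmVoltaTensorOp gemm_op;  // host-side operator object; arguments are supplied when it is invoked
  (void)gemm_op;
  return 0;
}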
@ -189,13 +189,27 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
NumStages>;
int main() {
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
//
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!(props.major >= 7)) {
std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70."
<< std::endl;
return 0;
return -1;
}
const int length_m = 5120;


@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU.
Writing a single high-performance matrix multiplication kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of a GEMM kernel. When used properly, the kernels can easily hit the peak
performance of the GPU.
@ -187,13 +187,28 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
NumStages>;
int main() {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major >= 7 && props.minor >= 5)) {
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 10.2.
//
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
return -1;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!((props.major * 10 + props.minor) >= 75)) {
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
<< std::endl;
return 0;
return -1;
}
const int length_m = 5120;


@ -0,0 +1,41 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Planar Complex GEMM example
cutlass_example_add_executable(
10_planar_complex
planar_complex.cu
)
#
# This example depends on the CUTLASS Library
#
target_link_libraries(
10_planar_complex
PRIVATE
cutlass_lib
cutlass_tools_util_includes
)


@ -0,0 +1,551 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Planar Complex GEMM
This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels supporting
the batched strided mode.
These kernels represent complex matrices by storing the real and imaginary parts of the matrix in
disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts
as either column-major or row-major layouts with a single leading dimension indicating the stride
between columns or rows.
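As a reminder of the arithmetic involved, the product of planar complex operands expands into four
real-valued GEMMs, which is why the example later counts four real-valued multiply-adds per complex
multiply-add:

  (A_{re} + i\,A_{im}) (B_{re} + i\,B_{im}) = (A_{re} B_{re} - A_{im} B_{im}) + i\,(A_{re} B_{im} + A_{im} B_{re})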
The CUTLASS Library collects multiple template instantiations in a data structure and offers
a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures.
CUTLASS decouples matrix layout from complex transformation, so four transformations
are possible on the A and B operands:
n: column-major
c: column-major complex conjugate
t: row-major
h: row-major complex conjugate
The CUTLASS Library contains many kernel instances specialized for architecture, data type, tile
size, and alignment. This can result in long compile times.
To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
This builds all planar complex GEMM variants for Volta and Turing architectures.
To build strictly the kernels needed for this example, an even narrower filter string may be
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn
$ make 10_planar_complex
$ ./examples/10_planar_complex/10_planar_complex --m=2048 --n=1024 --k=512 --batch=10
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor_planar_complex.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/gemm_planar_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/library/handle.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::gemm::GemmCoord problem_size;
int batch_count;
cutlass::complex<float> alpha;
cutlass::complex<float> beta;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({1024, 1024, 1024}),
batch_count(1),
reference_check(true),
iterations(20),
alpha(1),
beta() { }
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("batch", batch_count);
cmd.get_cmd_line_argument("alpha", alpha.real());
cmd.get_cmd_line_argument("alpha_i", alpha.imag());
cmd.get_cmd_line_argument("beta", beta.real());
cmd.get_cmd_line_argument("beta_i", beta.imag());
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "10_planar_complex example\n\n"
<< " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m <int> GEMM M dimension\n"
<< " --n <int> GEMM N dimension\n"
<< " --k <int> GEMM K dimension\n"
<< " --batch <int> Number of GEMM operations executed in one batch\n"
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
<< " --iterations <int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.product() * batch_count * 4;
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Performance test environment for planar complex
class TestbedPlanarComplex {
public:
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::ColumnMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = cutlass::half_t;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementCompute = float;
using ElementAccumulator = float;
//
// Data members
//
cutlass::library::Handle handle;
cutlass::gemm::GemmCoord problem_size;
int batch_count;
cutlass::DeviceAllocation<ElementA> tensor_A;
cutlass::DeviceAllocation<ElementB> tensor_B;
cutlass::DeviceAllocation<ElementC> tensor_C;
cutlass::DeviceAllocation<ElementC> tensor_D;
cutlass::DeviceAllocation<ElementC> tensor_D_ref;
//
// Methods
//
TestbedPlanarComplex(
Options const &options
):
problem_size(options.problem_size), batch_count(options.batch_count) {
// Allocate device memory for batched strided GEMM
tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2);
tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2);
tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
}
void initialize() {
uint64_t seed = 1073;
// Use small integers to simplify correctness checking
int scope_max = 6;
int scope_min = -6;
cutlass::reference::device::BlockFillRandomUniform(
tensor_A.get(), tensor_A.size(), seed, ElementA(scope_max), ElementA(scope_min), 0);
cutlass::reference::device::BlockFillRandomUniform(
tensor_B.get(), tensor_B.size(), seed * 2019, ElementB(scope_max), ElementB(scope_min), 0);
cutlass::reference::device::BlockFillRandomUniform(
tensor_C.get(), tensor_C.size(), seed * 2020, ElementC(scope_max), ElementC(scope_min), 0);
}
Result profile(Options const &options) {
Result result;
initialize();
ElementA *ptr_A = tensor_A.get();
ElementB *ptr_B = tensor_B.get();
ElementC *ptr_C = tensor_C.get();
ElementC *ptr_D = tensor_D.get();
int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2;
int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2;
int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n();
int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n();
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of GEMMs
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
//
// Execute the planar complex GEMM kernel via the CUTLASS Library's
// dispatch routines.
//
// Note, for planar complex GEMM kernels, all numeric type arguments
// specify the data type of the base real types. These are understood to
// apply to planar complex representations of matrices in memory and to complex<T>
// structures for scalars.
//
// See tools/library/include/cutlass/library/handle.h for more details.
//
result.status = handle.gemm_planar_complex(
problem_size.m(), // GEMM M dimension
problem_size.n(), // GEMM N dimension
problem_size.k(), // GEMM K dimension
cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation
cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars
&options.alpha, // Pointer to alpha scalar, of type complex<T>
cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix
cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix
cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand
ptr_A, // Pointer to real part of A matrix
ptr_A + imag_stride_A, // Pointer to imaginary part of A matrix
lda, // Leading dimension of real part of A matrix
lda, // Leading dimension of imaginary part of A matrix
cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix
cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix
cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand
ptr_B, // Pointer to real part of B matrix
ptr_B + imag_stride_B, // Pointer to imaginary part of B matrix
ldb, // Leading dimension of real part of B matrix
ldb, // Leading dimension of imaginary part of B matrix
&options.beta, // Pointer to beta scalar, of type complex<T>
cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices
ptr_C, // Pointer to real part of C matrix
ptr_C + imag_stride_C, // Pointer to imaginary part of C matrix
ldc, // Leading dimension of real part of C matrix
ldc, // Leading dimension of imaginary part of C matrix
ptr_D, // Pointer to real part of D matrix
ptr_D + imag_stride_D, // Pointer to imaginary part of D matrix
ldd, // Leading dimension of real part of D matrix
ldd, // Leading dimension of imaginary part of D matrix
batch_count, // Number of batched elements
batch_stride_A, // Stride between batches of real parts of A matrix
batch_stride_A, // Stride between batches of imaginary parts of A matrix
batch_stride_B, // Stride between batches of real parts of B matrix
batch_stride_B, // Stride between batches of imaginary parts of B matrix
batch_stride_C, // Stride between batches of real parts of C matrix
batch_stride_C, // Stride between batches of imaginary parts of C matrix
batch_stride_D, // Stride between batches of real parts of D matrix
batch_stride_D // Stride between batches of imaginary parts of D matrix
);
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "CUTLASS internal error - configuration not supported" << std::endl;
return result;
}
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
if (handle.get_last_operation()) {
std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl;
}
//
// Compute reference in device code
//
if (options.reference_check) {
result.passed = true;
for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) {
cutlass::reference::device::GemmPlanarComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator
>(
problem_size,
options.alpha,
{tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A},
cutlass::ComplexTransform::kConjugate,
{tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B},
cutlass::ComplexTransform::kNone,
options.beta,
{tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C},
{tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D}
);
ElementC epsilon = 0.1_hf;
ElementC nonzero_floor = 0.1_hf;
result.passed = cutlass::reference::device::BlockCompareRelativelyEqual(
tensor_D.get() + idx * batch_stride_D,
tensor_D_ref.get() + idx * batch_stride_D,
batch_stride_D,
epsilon,
nonzero_floor
);
}
if (result.passed) {
std::cout << "Reference check passed." << std::endl;
}
else {
std::cerr << "Error - reference check failed." << std::endl;
}
}
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
// Volta Tensor Core operations are first available in CUDA 10.1 Toolkit.
//
// Turing Tensor Core operations are first available in CUDA 10.2 Toolkit.
//
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major < 7) {
std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70."
<< std::endl;
return -1;
}
else if (props.major == 7 && props.minor <= 2) {
//
// If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
}
}
else if (props.major == 7 && props.minor >= 5) {
//
// If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
return -1;
}
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
TestbedPlanarComplex testbed(options);
Result result = testbed.profile(options);
return result.passed ? 0 : -1;
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,41 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Planar Complex Array GEMM example
cutlass_example_add_executable(
11_planar_complex_array
planar_complex_array.cu
)
#
# This example depends on the CUTLASS Library
#
target_link_libraries(
11_planar_complex_array
PRIVATE
cutlass_lib
cutlass_tools_util_includes
)


@ -0,0 +1,611 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Planar Complex Array Example
This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels which
execute a batch of matrix products, loading problem sizes and matrix base pointers from arrays
in global memory.
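Because the kernels read per-batch base pointers from device memory, the host must stage those
pointer arrays on the GPU before launching. The following is a hypothetical sketch using plain CUDA
allocations; the variable names and sizes are illustrative, and the example itself uses CUTLASS
utility allocators instead.

// Hypothetical staging of per-batch base pointers for the planar complex array kernels.
#include <vector>
#include <cuda_runtime.h>
#include "cutlass/half.h"

int main() {
  int batch_count = 10;
  size_t matrix_elements = 1024 * 512;   // elements in one real (or imaginary) part of A

  // One contiguous allocation holding every batch; each batch stores its real part
  // immediately followed by its imaginary part.
  cutlass::half_t *block_A = nullptr;
  cudaMalloc(&block_A, sizeof(cutlass::half_t) * matrix_elements * 2 * batch_count);

  // Host-side arrays of per-batch base pointers for the real and imaginary parts.
  std::vector<void *> ptr_A_real_host(batch_count), ptr_A_imag_host(batch_count);
  for (int b = 0; b < batch_count; ++b) {
    ptr_A_real_host[b] = block_A + size_t(b) * matrix_elements * 2;
    ptr_A_imag_host[b] = block_A + size_t(b) * matrix_elements * 2 + matrix_elements;
  }

  // Copy the pointer arrays to device memory, where the kernel reads them at runtime.
  void **ptr_A_real = nullptr;
  void **ptr_A_imag = nullptr;
  cudaMalloc(&ptr_A_real, sizeof(void *) * batch_count);
  cudaMalloc(&ptr_A_imag, sizeof(void *) * batch_count);
  cudaMemcpy(ptr_A_real, ptr_A_real_host.data(), sizeof(void *) * batch_count, cudaMemcpyHostToDevice);
  cudaMemcpy(ptr_A_imag, ptr_A_imag_host.data(), sizeof(void *) * batch_count, cudaMemcpyHostToDevice);

  // The B, C, and D pointer arrays are prepared the same way before invoking the library.

  cudaFree(ptr_A_imag);
  cudaFree(ptr_A_real);
  cudaFree(block_A);
  return 0;
}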
These kernels represent complex matrices by storing the real and imaginary parts of the matrix in
disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts
as either column-major or row-major layouts with a single leading dimension indicating the stride
between columns or rows.
The CUTLASS Library collects multiple template instantiations in a data structure and offers
a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures.
CUTLASS decouples matrix layout from complex transformation, so four transformations
are possible on the A and B operands:
n: column-major
c: column-major complex conjugate
t: row-major
h: row-major complex conjugate
To build strictly the planar complex kernels needed for general application, execute the following
CMake command in an empty build directory.
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
This builds all planar complex GEMM variants for Volta and Turing architectures.
To build strictly the kernels needed for this example, an even narrower filter string may be
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn
$ make 11_planar_complex_array
$ ./examples/11_planar_complex_array/11_planar_complex_array --m=2048 --n=1024 --k=512 --batch=10
*/
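// (Illustrative note, not part of the original example.) For the column-major,
// planar complex storage convention described above, batch 'b' of an m-by-k
// matrix A with leading dimension lda = m is addressed as:
//
//   real part: ptr_A + b * (2 * m * k)
//   imag part: ptr_A + b * (2 * m * k) + m * k      // imag_stride = m * k
//
// The Testbed below computes batch_stride and imag_stride exactly this way when
// it fills the per-batch pointer arrays consumed by the array GEMM kernel.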
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor_planar_complex.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/gemm_planar_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/library/handle.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::gemm::GemmCoord problem_size;
int batch_count;
cutlass::complex<float> alpha;
cutlass::complex<float> beta;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({1024, 1024, 1024}),
batch_count(1),
reference_check(true),
iterations(20),
alpha(1),
beta() { }
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("batch", batch_count);
cmd.get_cmd_line_argument("alpha", alpha.real());
cmd.get_cmd_line_argument("alpha_i", alpha.imag());
cmd.get_cmd_line_argument("beta", beta.real());
cmd.get_cmd_line_argument("beta_i", beta.imag());
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "11_planar_complex_array example\n\n"
<< " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m <int> GEMM M dimension\n"
<< " --n <int> GEMM N dimension\n"
<< " --k <int> GEMM K dimension\n"
<< " --batch <int> Number of GEMM operations executed in one batch\n"
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
<< " --iterations <int> Number of profiling iterations to perform.\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.product() * batch_count * 4;
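// (A complex multiply-add expands into four real-valued multiply-adds:
//  ar*br, ai*bi, ar*bi, and ai*br - hence the factor of 4.)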
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Performance test environment for planar complex
class TestbedPlanarComplex {
public:
// Half-precision input and output
using Element = cutlass::half_t;
// Configurations for layouts and internal computation
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using ElementCompute = float;
using ElementAccumulator = float;
//
// Data members
//
cutlass::library::Handle handle;
cutlass::gemm::GemmCoord problem_size;
int batch_count;
cutlass::DeviceAllocation<Element> tensor_A;
cutlass::DeviceAllocation<Element> tensor_B;
cutlass::DeviceAllocation<Element> tensor_C;
cutlass::DeviceAllocation<Element> tensor_D;
cutlass::DeviceAllocation<Element> tensor_D_ref;
cutlass::DeviceAllocation<void *> ptr_A_real;
cutlass::DeviceAllocation<void *> ptr_A_imag;
cutlass::DeviceAllocation<void *> ptr_B_real;
cutlass::DeviceAllocation<void *> ptr_B_imag;
cutlass::DeviceAllocation<void *> ptr_C_real;
cutlass::DeviceAllocation<void *> ptr_C_imag;
cutlass::DeviceAllocation<void *> ptr_D_real;
cutlass::DeviceAllocation<void *> ptr_D_imag;
//
// Methods
//
TestbedPlanarComplex(
Options const &options
):
problem_size(options.problem_size), batch_count(options.batch_count) {
// Allocate device memory for batched planar complex GEMM
tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2);
tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2);
tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2);
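// (Each allocation holds batch_count matrices; the factor of 2 accounts for the
//  real and imaginary parts stored back-to-back in the planar complex layout.)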
ptr_A_real.reset(batch_count);
ptr_A_imag.reset(batch_count);
ptr_B_real.reset(batch_count);
ptr_B_imag.reset(batch_count);
ptr_C_real.reset(batch_count);
ptr_C_imag.reset(batch_count);
ptr_D_real.reset(batch_count);
ptr_D_imag.reset(batch_count);
}
void initialize() {
uint64_t seed = 1073;
// Use small integers to simplify correctness checking
int scope_max = 6;
int scope_min = -6;
cutlass::reference::device::BlockFillRandomUniform(
tensor_A.get(), tensor_A.size(), seed, Element(scope_max), Element(scope_min), 0);
cutlass::reference::device::BlockFillRandomUniform(
tensor_B.get(), tensor_B.size(), seed * 2019, Element(scope_max), Element(scope_min), 0);
cutlass::reference::device::BlockFillRandomUniform(
tensor_C.get(), tensor_C.size(), seed * 2020, Element(scope_max), Element(scope_min), 0);
}
Result profile(Options const &options) {
Result result;
initialize();
Element *ptr_A = tensor_A.get();
Element *ptr_B = tensor_B.get();
Element *ptr_C = tensor_C.get();
Element *ptr_D = tensor_D.get();
int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2;
int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2;
int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n();
int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n();
//
// Configure pointers in global memory
//
struct {
Element *base;
void **ptr_real;
void **ptr_imag;
int64_t batch_stride;
int64_t imag_stride;
} tensors[] = {
{ tensor_A.get(), ptr_A_real.get(), ptr_A_imag.get(), batch_stride_A, imag_stride_A},
{ tensor_B.get(), ptr_B_real.get(), ptr_B_imag.get(), batch_stride_B, imag_stride_B},
{ tensor_C.get(), ptr_C_real.get(), ptr_C_imag.get(), batch_stride_C, imag_stride_C},
{ tensor_D.get(), ptr_D_real.get(), ptr_D_imag.get(), batch_stride_D, imag_stride_D}
};
for (auto const &tensor : tensors) {
for (int idx = 0; idx < batch_count; ++idx) {
void *ptr_real = tensor.base + idx * tensor.batch_stride;
void *ptr_imag = tensor.base + idx * tensor.batch_stride + tensor.imag_stride;
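// Within each batch, the imaginary part begins imag_stride elements after the
// real part, so both pointers are derived from the same base allocation.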
cudaError_t error = cudaMemcpy(
tensor.ptr_real + idx,
&ptr_real,
sizeof(void *),
cudaMemcpyHostToDevice);
if (error != cudaSuccess) {
throw std::runtime_error("Failed to copy pointer to device memory");
}
error = cudaMemcpy(
tensor.ptr_imag + idx,
&ptr_imag,
sizeof(void *),
cudaMemcpyHostToDevice);
if (error != cudaSuccess) {
throw std::runtime_error("Failed to copy pointer to device memory");
}
}
}
//
// Construct events
//
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of GEMM operations
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
//
// Execute the planar complex array GEMM kernel via the CUTLASS Library's
// dispatch routines.
//
// Note, for planar complex array GEMM kernels, all numeric type arguments
// specify the data type of the base real types. These are understood to
// apply to planar complex representations of matrices in memory and to complex<T>
// structures for scalars.
//
// See tools/library/include/cutlass/library/handle.h for more details.
//
result.status = handle.gemm_planar_complex_array(
problem_size.m(), // expected GEMM M dimension
problem_size.n(), // expected GEMM N dimension
problem_size.k(), // expected GEMM K dimension
batch_count, // Number of batched elements
nullptr,   // Pointer to array of per-batch GEMM M dimensions in device memory (not used in this example)
nullptr,   // Pointer to array of per-batch GEMM N dimensions in device memory (not used in this example)
nullptr,   // Pointer to array of per-batch GEMM K dimensions in device memory (not used in this example)
cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation
cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars
&options.alpha, // Pointer to alpha scalar, of type complex<T>
cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix
cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix
cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand
ptr_A_real.get(), // Pointer to array of pointers to real part of A matrix
ptr_A_imag.get(), // Pointer to array of pointers to imaginary part of A matrix
lda, // Leading dimension of real part of A matrix
lda, // Leading dimension of imaginary part of A matrix
cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix
cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix
cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand
ptr_B_real.get(), // Pointer to array of pointers to real part of B matrix
ptr_B_imag.get(), // Pointer to array of pointers to imaginary part of B matrix
ldb, // Leading dimension of real part of B matrix
ldb, // Leading dimension of imaginary part of B matrix
&options.beta, // Pointer to beta scalar, of type complex<T>
cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices
ptr_C_real.get(), // Pointer to array of pointers to real part of C matrix
ptr_C_imag.get(), // Pointer to array of pointers to imaginary part of C matrix
ldc, // Leading dimension of real part of C matrix
ldc, // Leading dimension of imaginary part of C matrix
ptr_D_real.get(), // Pointer to array of pointers to real part of D matrix
ptr_D_imag.get(), // Pointer to array of pointers to imaginary part of D matrix
ldd, // Leading dimension of real part of D matrix
ldd // Leading dimension of imaginary part of D matrix
);
if (result.status != cutlass::Status::kSuccess) {
std::cerr << "CUTLASS internal error - configuration not supported" << std::endl;
return result;
}
}
//
// Stop profiling loop
//
// Record an event when the GEMM operations have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
if (handle.get_last_operation()) {
std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl;
}
//
// Compute reference in device code
//
if (options.reference_check) {
result.passed = true;
for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) {
cutlass::reference::device::GemmPlanarComplex<
Element, LayoutA,
Element, LayoutB,
Element, LayoutC,
ElementAccumulator
>(
problem_size,
options.alpha,
{tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A},
cutlass::ComplexTransform::kConjugate,
{tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B},
cutlass::ComplexTransform::kNone,
options.beta,
{tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C},
{tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D}
);
Element epsilon = 0.1_hf;
Element nonzero_floor = 0.1_hf;
result.passed = cutlass::reference::device::BlockCompareRelativelyEqual(
tensor_D.get() + idx * batch_stride_D,
tensor_D_ref.get() + idx * batch_stride_D,
batch_stride_D,
epsilon,
nonzero_floor
);
}
if (result.passed) {
std::cout << "Reference check passed." << std::endl;
}
else {
std::cerr << "Error - reference check failed." << std::endl;
}
}
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
// Volta Tensor Core operations are first available in CUDA 10.1 Toolkit.
//
// Turing Tensor Core operations are first available in CUDA 10.2 Toolkit.
//
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major < 7) {
std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70."
<< std::endl;
return -1;
}
else if (props.major == 7 && props.minor <= 2) {
//
// If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
}
}
else if (props.major == 7 && props.minor >= 5) {
//
// If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
return -1;
}
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
TestbedPlanarComplex testbed(options);
Result result = testbed.profile(options);
return result.passed ? 0 : -1;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -22,16 +22,14 @@
set(CUTLASS_EXAMPLES_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common)
function(cutlass_example_add_executable)
function(cutlass_example_add_executable NAME)
set(options)
set(oneValueArgs)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_add_executable(${__UNPARSED_ARGUMENTS})
list(GET __UNPARSED_ARGUMENTS 0 NAME)
cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS})
target_link_libraries(
${NAME}
@ -59,7 +57,10 @@ foreach(EXAMPLE
05_batched_gemm
06_splitK_gemm
07_volta_tensorop_gemm
08_turing_tensorop_gemm)
08_turing_tensorop_gemm
10_planar_complex
11_planar_complex_array
)
add_subdirectory(${EXAMPLE})
add_dependencies(cutlass_examples ${EXAMPLE})

View File

@ -46,61 +46,68 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
// Determine the appropriate way to target PTX's "ldmatrix" instruction.
//
/////////////////////////////////////////////////////////////////////////////////////////////////
#if (__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ == 2)
#define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 1
#else
#define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 0
#endif
#if ! defined(CUDA_NVVM_GET_SHARED_POINTER_ENABLED)
#define CUDA_NVVM_GET_SHARED_POINTER_ENABLED (CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED)
#endif
#if ! defined(CUDA_LDMATRIX_SUPPORTED)
#define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))
#endif
#if ! defined(CUDA_LDMATRIX_ENABLED)
#define CUDA_LDMATRIX_ENABLED (CUDA_LDMATRIX_SUPPORTED)
#define CUDA_LDMATRIX_ENABLED CUDA_LDMATRIX_SUPPORTED
#endif
#if (CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#if CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
#define CUDA_LDMATRIX_ACTIVATED 1
#else
#define CUDA_LDMATRIX_ACTIVATED 0
#endif
#if defined(CUTLASS_GET_SMEM_POINTER)
// Use the existing implementation
#elif CUDA_NVVM_GET_SHARED_POINTER_ENABLED
#if ! defined(NVVM_GET_SMEM_POINTER)
#define NVVM_GET_SMEM_POINTER
/////////////////////////////////////////////////////////////////////////////////////////////////
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
#endif
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif
#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void*);
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void *);
}
#endif
#define CUTLASS_GET_SMEM_POINTER(ptr) __nvvm_get_smem_pointer((void*)ptr)
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
return __nvvm_get_smem_pointer(const_cast<void *>(ptr));
}
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
return __nvvm_get_smem_pointer(ptr);
}
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ void ldsm<layout::RowMajor, 1>(
Array<unsigned, 1> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)
unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
unsigned addr = cutlass_get_smem_pointer(ptr);
int x;
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr));
@ -120,9 +127,9 @@ inline __device__ void ldsm<layout::RowMajor, 2>(
Array<unsigned, 2> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)
unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
unsigned addr = cutlass_get_smem_pointer(ptr);
int x, y;
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr));
@ -142,9 +149,9 @@ inline __device__ void ldsm<layout::RowMajor, 4>(
Array<unsigned, 4> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)
unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
unsigned addr = cutlass_get_smem_pointer(ptr);
int x, y, z, w;
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr));
@ -167,9 +174,10 @@ template <>
inline __device__ void ldsm<layout::ColumnMajor, 1>(
Array<unsigned, 1> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
unsigned addr = cutlass_get_smem_pointer(ptr);
int x;
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr));
@ -189,9 +197,9 @@ inline __device__ void ldsm<layout::ColumnMajor, 2>(
Array<unsigned, 2> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)
unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
unsigned addr = cutlass_get_smem_pointer(ptr);
int x, y;
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr));
@ -211,9 +219,9 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
Array<unsigned, 4> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUDA_LDMATRIX_ACTIVATED)
unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr);
unsigned addr = cutlass_get_smem_pointer(ptr);
int x, y, z, w;
asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr));

View File

@ -30,7 +30,9 @@
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/arch.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -49,6 +51,11 @@ struct OpMultiplyAddSaturate;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the complex multiply-add operation
struct OpMultiplyAddComplex;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the inner product is defined by (XOR, POPC)
struct OpXorPopc;

View File

@ -27,7 +27,11 @@
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "mma.h"
#include "cutlass/layout/matrix.h"
@ -84,6 +88,7 @@ struct Mma<
using FragmentC = Array<half_t, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
CUTLASS_HOST_DEVICE
void operator()(
@ -139,6 +144,7 @@ struct Mma<
using FragmentC = Array<half_t, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
CUTLASS_HOST_DEVICE
void operator()(
@ -194,6 +200,7 @@ struct Mma<
using FragmentC = Array<half_t, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
CUTLASS_HOST_DEVICE
void operator()(
@ -249,6 +256,7 @@ struct Mma<
using FragmentC = Array<half_t, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
CUTLASS_HOST_DEVICE
void operator()(
@ -310,6 +318,7 @@ struct Mma<
using FragmentC = Array<float, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
/// Multiply-add
CUTLASS_HOST_DEVICE
@ -385,6 +394,7 @@ struct Mma<
using FragmentC = Array<float, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
/// Multiply-add
CUTLASS_HOST_DEVICE
@ -460,6 +470,7 @@ struct Mma<
using FragmentC = Array<float, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
/// Multiply-add
CUTLASS_HOST_DEVICE
@ -535,6 +546,7 @@ struct Mma<
using FragmentC = Array<float, 8>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm70;
/// Multiply-add
CUTLASS_HOST_DEVICE

View File

@ -28,7 +28,11 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/arch/wmma.h"
@ -93,6 +97,7 @@ struct Mma<
using FragmentC = Array<half_t, 4>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
CUTLASS_HOST_DEVICE
void operator()(
@ -154,6 +159,7 @@ struct Mma<
using FragmentC = Array<float, 4>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -215,6 +221,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -271,6 +278,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -327,6 +335,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -384,6 +393,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -446,6 +456,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -502,6 +513,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -558,6 +570,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -614,6 +627,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -676,6 +690,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -732,6 +747,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -788,6 +804,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -844,6 +861,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -906,6 +924,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -962,6 +981,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -1018,6 +1038,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -1074,6 +1095,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpMultiplyAddSaturate;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE
@ -1136,6 +1158,7 @@ struct Mma<
using FragmentC = Array<int, 2>;
using Operator = OpXorPopc;
using ArchTag = arch::Sm75;
/// Computes multiply-add
CUTLASS_HOST_DEVICE

View File

@ -68,24 +68,6 @@
namespace cutlass {
namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// MemoryKind class (Shared vs. Global memory)
/////////////////////////////////////////////////////////////////////////////////////////////////
enum class MemoryKind {
kShared, // Data resides in shared memory
kGlobal // Data resides in global memory
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// WarpParams holds architecture-specific constants
/////////////////////////////////////////////////////////////////////////////////////////////////
struct WarpParams {
static int const kThreadsPerWarp = 32;
static int const kQuadsPerWarp = 8;
static int const kThreadsPerQuad = 4;
};
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass data types => nvcuda::wmma data types
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -196,7 +178,6 @@ template <
struct Wmma;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass

View File

@ -1,105 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing warp matrix multiply-add (WMMA) operations
*/
#pragma once
#include "cutlass/arch/wmma.h"
namespace cutlass {
namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
///
/// WMMA structures to enclose * PTX * instruction string
///
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// WMMA PTX string load for A, B, and C matrices
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_, ///< Size of the matrix product (concept: GemmShape)
typename Element_, ///< Data type of elements
typename Layout_, ///< Layout of matrix (concept: MatrixLayout)
MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory
>
struct PtxWmmaLoadA;
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_, ///< Size of the matrix product (concept: GemmShape)
typename Element_, ///< Data type of elements
typename Layout_, ///< Layout of matrix (concept: MatrixLayout)
MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory
>
struct PtxWmmaLoadB;
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_, ///< Size of the matrix product (concept: GemmShape)
typename Element_, ///< Data type of elements
typename Layout_, ///< Layout of matrix (concept: MatrixLayout)
MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory
>
struct PtxWmmaLoadC;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// WMMA Matrix multiply-add operation
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_, ///< Size of the matrix product (concept: GemmShape)
typename ElementA_, ///< Data type of A elements
typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout)
typename ElementB_, ///< Data type of B elements
typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout)
typename ElementC_, ///< Element type of C matrix
typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout)
typename Operator = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc)
>
struct PtxWmma;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// WMMA store for matrix D
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Shape_, ///< Size of the matrix product (concept: GemmShape)
typename Element_, ///< Data type of elements
typename Layout_, ///< Layout of matrix (concept: MatrixLayout)
MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory
>
struct PtxWmmaStoreD;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -28,7 +28,11 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/layout/matrix.h"
////////////////////////////////////////////////////////////////////////////////
@ -68,6 +72,7 @@ struct Wmma<
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using Operator = cutlass::arch::OpMultiplyAdd;
using ArchTag = arch::Sm70;
// check supported wmma shape for the given multiplicand data types
static_assert(

View File

@ -28,7 +28,11 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/layout/matrix.h"
////////////////////////////////////////////////////////////////////////////////
@ -65,6 +69,7 @@ struct Wmma<
using ElementC = int32_t;
using LayoutC = LayoutC_;
using Operator = cutlass::arch::OpMultiplyAdd;
using ArchTag = arch::Sm72;
// check supported wmma shape for the given multiplicand data types
static_assert(
@ -145,6 +150,7 @@ struct Wmma<
using ElementC = int32_t;
using LayoutC = LayoutC_;
using Operator = cutlass::arch::OpMultiplyAdd;
using ArchTag = arch::Sm72;
// check supported wmma shape for the given multiplicand data types
static_assert(

View File

@ -28,7 +28,11 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/layout/matrix.h"
////////////////////////////////////////////////////////////////////////////////
@ -65,6 +69,7 @@ struct Wmma<
using ElementC = int32_t;
using LayoutC = LayoutC_;
using Operator = cutlass::arch::OpMultiplyAdd;
using ArchTag = arch::Sm75;
// check supported wmma shape for the given multiplicand data types
static_assert(
@ -143,6 +148,7 @@ struct Wmma<
using ElementC = int32_t;
using LayoutC = LayoutC_;
using Operator = cutlass::arch::OpXorPopc;
using ArchTag = arch::Sm75;
// check supported wmma shape for the given multiplicand data types
static_assert(

View File

@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Array holding planar complex elements
template <typename Element_, int N>
struct ArrayPlanarComplex {
/// Underlying real element
using Element = Element_;
/// Number of logical elements
static size_t const kElements = N;
/// Underlying Fragment of real-valued elements
using ArrayReal = Array<Element, N>;
public:
/// Fragment of real-valued elements representing the real part
ArrayReal real;
/// Fragment of real-valued elements representing the imaginary part
ArrayReal imag;
public:
/// Ctor
CUTLASS_HOST_DEVICE
ArrayPlanarComplex() { }
/// Ctor
CUTLASS_HOST_DEVICE
ArrayPlanarComplex(
ArrayReal const &real_,
ArrayReal const &imag_
):
real(real_), imag(imag_) { }
/// Sets the array to zero efficiently
CUTLASS_HOST_DEVICE
void clear() {
real.clear();
imag.clear();
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to deduce template arguments
template <typename Element, int N>
CUTLASS_HOST_DEVICE
ArrayPlanarComplex<Element, N>
make_ArrayPlanarComplex(Array<Element, N> const &real, Array<Element, N> const &imag) {
return ArrayPlanarComplex<Element, N>(real, imag);
}
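// Illustrative usage sketch (not part of the original header): pair two real-valued
// fragments into a planar complex fragment and zero it. The element type and extent
// below are chosen arbitrarily for demonstration.
CUTLASS_HOST_DEVICE
void array_planar_complex_usage_sketch() {
  Array<float, 4> re;
  Array<float, 4> im;
  re.clear();
  im.clear();
  ArrayPlanarComplex<float, 4> frag = make_ArrayPlanarComplex(re, im);
  frag.clear();   // zeroes both the real and imaginary fragments
}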
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -25,7 +25,11 @@
#pragma once
#include <cuComplex.h>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/half.h"
@ -351,6 +355,16 @@ CUTLASS_HOST_DEVICE R norm_accumulate(complex<T> const &z, R const &accumulator)
static_cast<R>(imag(z)) * static_cast<R>(imag(z));
}
/// Returns the complex conjugate
CUTLASS_HOST_DEVICE float conj(float const &z) {
return z;
}
/// Returns the complex conjugate
CUTLASS_HOST_DEVICE double conj(double const &z) {
return z;
}
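// (The real-valued overloads above let generic code call conj() uniformly,
//  whether it is instantiated with real or complex element types.)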
/// Returns the complex conjugate
template <typename T>
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
@ -414,6 +428,10 @@ CUTLASS_HOST_DEVICE complex<T> sin(complex<T> const &z) {
template <typename T>
struct RealType< complex<T> > {
using Type = T;
static complex<T> from_real(double x) {
return complex<T>(static_cast<T>(x));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -438,5 +456,18 @@ cutlass::complex<double> from_real<cutlass::complex<double> >(double r) {
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct is_complex {
static bool const value = false;
};
template <typename T>
struct is_complex<complex<T>> {
static bool const value = true;
};
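// Illustrative usage sketch (not part of the original header): the trait can guard
// complex-only code paths at compile time.
static_assert(!is_complex<float>::value, "real types report false");
static_assert(is_complex<complex<double>>::value, "complex types report true");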
//////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
//////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -28,6 +28,12 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <stdint.h>
#endif
#include "cutlass/cutlass.h"
namespace cutlass {

View File

@ -35,10 +35,21 @@ namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
#define CUTLASS_DEVICE __forceinline__ __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
#define CUTLASS_DEVICE __forceinline__ __device__
#else
#define CUTLASS_HOST_DEVICE inline
#endif
/// Status code returned by CUTLASS operations
enum class Status {
kSuccess, ///< Operation was successful.
kErrorMisalignedOperand, ///< operands fail alignment requirements.
kErrorInvalidDataType, ///< DataType fails requirement.
kErrorInvalidLayout, ///< Layout fails alignment requirement.
kErrorInvalidProblem, ///< Specified problem size is not supported by operator.
kErrorNotSupported, ///< Operation is not supported on current device.
@ -48,12 +59,15 @@ enum class Status {
};
/// Convert cutlass status to status strings
static inline char const* cutlassGetStatusString(cutlass::Status status) {
CUTLASS_HOST_DEVICE
static char const* cutlassGetStatusString(cutlass::Status status) {
switch (status) {
case cutlass::Status::kSuccess:
return "Success";
case cutlass::Status::kErrorMisalignedOperand:
return "Error Misaligned Operand";
case cutlass::Status::kErrorInvalidDataType:
return "Error Invalid Data Type";
case cutlass::Status::kErrorInvalidLayout:
return "Error Invalid Layout";
case cutlass::Status::kErrorInvalidProblem:
@ -79,16 +93,6 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
#define CUTLASS_DEVICE __forceinline__ __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
#define CUTLASS_DEVICE __forceinline__ __device__
#else
#define CUTLASS_HOST_DEVICE inline
#endif
#define CUTLASS_ASSERT(x) assert(x)
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -115,6 +119,12 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
struct Debug {
typename T::X x;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
static const int NUM_THREADS_PER_WARP = 32;
static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2;
@ -131,6 +141,14 @@ int LaneId() {
return ret;
}
/// Computes SM number the thread is running on
CUTLASS_DEVICE
int SmId() {
int ret;
asm ("mov.u32 %0, %%smid;" : "=r"(ret));
return ret;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -167,23 +167,23 @@ public:
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
/// Clamping constant value
ElementCompute const kClamp = ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
intermediate = max_accumulator(intermediate, -kClamp);
intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
ElementCompute const kClamp =
ElementCompute((1U << (sizeof_bits<ElementOutput>::value - 1)) - 1);
intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1));
intermediate = min_accumulator(intermediate, kClamp);
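// (For example, with ElementOutput = int8_t, sizeof_bits is 8, so kClamp = 127 and
//  the intermediate value is clamped to the representable range [-128, 127].)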
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
@ -320,7 +320,142 @@ public:
#endif // Conditional guards to enable partial specialization for packed integers
/////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements then clamps
/// the output before converting to the output element type.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Note: The below method only works for small k dimensions. The default
/// approach is above
/// TODO: Add logic to fallback to the default approach
template <
/// Data type used to load and store tensors
typename ElementOutput_,
/// Number of elements computed per operation
int Count,
/// Rounding mode
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class FastLinearCombinationClamp {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = int;
using ElementCompute = float;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
/// scales accumulators
ElementCompute alpha;
/// scales source tensor
ElementCompute beta;
/// pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *alpha_ptr;
/// pointer to source scalar - if not null, loads it from memory
ElementCompute const *beta_ptr;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha, ElementCompute beta)
: alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
FastLinearCombinationClamp(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return beta_ != ElementCompute(0); }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition) {
if (k_partition) {
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator,
FragmentOutput const &source,
ElementCompute uniform = ElementCompute(0)) const {
// Convert source to internal compute numeric type
FastNumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
source_converter;
FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Compute linear scaling in floating point
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
minimum<ComputeFragment> min_accumulator;
maximum<ComputeFragment> max_accumulator;
// Float min-max
intermediate =
mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator,
intermediate); // D = alpha * Accum + X
/// Clamping constant value
ElementCompute const kClamp =
ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
intermediate = max_accumulator(intermediate, -kClamp);
intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
// Convert to destination numeric type
FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue

View File

@ -0,0 +1,196 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations on planar-complex arrays
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/complex.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to arrays of planar-complex elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Note, as with most CUTLASS components for planar complex, the template arguments describe
/// the underlying real data type.
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationPlanarComplex {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = ArrayPlanarComplex<ElementOutput, kCount>;
using FragmentAccumulator = ArrayPlanarComplex<ElementAccumulator, kCount>;
using ComputeFragment = ArrayPlanarComplex<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
complex<ElementCompute> alpha; ///< scales accumulators
complex<ElementCompute> beta; ///< scales source tensor
complex<ElementCompute> const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
complex<ElementCompute> const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
complex<ElementCompute> alpha,
complex<ElementCompute> beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
complex<ElementCompute> const *alpha_ptr,
complex<ElementCompute> const *beta_ptr
): alpha(complex<ElementCompute>()), beta(complex<ElementCompute>()), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
complex<ElementCompute> alpha_;
complex<ElementCompute> beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationPlanarComplex(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition) {
if (k_partition) {
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source(
source_converter(source.real),
source_converter(source.imag));
ComputeFragment converted_accumulator(
accumulator_converter(accumulator.real),
accumulator_converter(accumulator.imag));
// Perform binary operations
ComputeFragment intermediate;
multiplies<Array<ElementCompute, kCount> > mul_op;
multiply_add<Array<ElementCompute, kCount> > mul_add_op;
// complex multiply: I = beta * C
intermediate.real = mul_op(beta_.real(), converted_source.real);
intermediate.imag = mul_op(beta_.real(), converted_source.imag);
intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real);
intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag);
// complex multiply-add: I = alpha * AB + I
intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real);
intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag);
intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return FragmentOutput(
destination_converter(intermediate.real),
destination_converter(intermediate.imag));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
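A scalar-level sketch of the arithmetic performed by LinearCombinationPlanarComplex above, expanded from D = alpha * AB + beta * C with complex-valued alpha and beta (the names below are illustrative only, not part of the header):

// Planar-complex linear combination, real and imaginary parts computed separately:
//   D.real = beta.real * C.real   - beta.imag * C.imag
//          + alpha.real * AB.real - alpha.imag * AB.imag
//   D.imag = beta.real * C.imag   + beta.imag * C.real
//          + alpha.real * AB.imag + alpha.imag * AB.real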

View File

@ -180,9 +180,13 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
///
@ -327,6 +331,8 @@ public:
}
};
#endif // Conditional guards to enable partial specialization for packed integers
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread

View File

@ -58,16 +58,25 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for TensorOps.
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization defining sensible defaults for epilogues in the complex*complex case
// Each complex multiply-accumulate requires 4 real-valued mma operations (Complex)
// A = (ar + j ai), B = (br + j bi), D = AB
// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br)
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Epilogue Shape
typename Shape_,
/// Warp-level mma operator
typename WarpMmaTensorOp_,
/// Number of k partitions
int PartitionsK,
/// Epilogue output operator
typename OutputOp_,
int ElementsPerAccess
>
/// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
int ElementsPerAccess,
/// Multiply-add operator
typename Operator_ = arch::OpMultiplyAddComplex>
struct DefaultEpilogueComplexTensorOp {
using Shape = Shape_;
@ -75,6 +84,7 @@ struct DefaultEpilogueComplexTensorOp {
static int const kPartitionsK = PartitionsK;
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using Operator = Operator_;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;

View File

@ -0,0 +1,197 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Constructs a default epilogue for planar complex outputs.
This template reuses components for real-valued epilogues and applies them to planar complex
output matrices.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/arch/arch.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues.
template <
typename ThreadblockShape_,
typename WarpMma_,
typename OpcodeClass_,
typename ArchTag_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues.
template <
typename ThreadblockShape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
arch::OpClassTensorOp,
arch::Sm70,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using RealEpilogue = DefaultEpilogueVoltaTensorOp<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
OutputOp_,
ElementsPerAccess
>;
using Epilogue = EpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
typename RealEpilogue::OutputTileIterator,
typename RealEpilogue::AccumulatorFragmentIterator,
typename RealEpilogue::WarpTileIterator,
typename RealEpilogue::SharedLoadIterator,
OutputOp_,
typename RealEpilogue::Padding
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues.
template <
typename ThreadblockShape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
arch::OpClassTensorOp,
arch::Sm75,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using RealEpilogue = DefaultEpilogueTensorOp<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
OutputOp_,
ElementsPerAccess
>;
using Epilogue = EpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
typename RealEpilogue::OutputTileIterator,
typename RealEpilogue::AccumulatorFragmentIterator,
typename RealEpilogue::WarpTileIterator,
typename RealEpilogue::SharedLoadIterator,
OutputOp_,
typename RealEpilogue::Padding
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues.
template <
typename ThreadblockShape_,
typename WarpMmaOperator_,
typename ArchTag_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
arch::OpClassSimt,
ArchTag_,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using RealEpilogue = DefaultEpilogueSimt<
ThreadblockShape_,
WarpMmaOperator_,
OutputOp_,
ElementsPerAccess
>;
using Epilogue = EpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
typename RealEpilogue::OutputTileIterator,
typename RealEpilogue::AccumulatorFragmentIterator,
typename RealEpilogue::WarpTileIterator,
typename RealEpilogue::SharedLoadIterator,
OutputOp_,
typename RealEpilogue::Padding
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
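A minimal instantiation sketch of one of the specializations above; the warp-level MMA operator, tile shapes, and output operator parameters are assumptions supplied by the caller rather than definitions from this header:

using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex<float, 4>;

using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex<
  cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile shape
  WarpMmaTensorOp,                          // warp-level MMA operator (assumed defined elsewhere)
  cutlass::arch::OpClassTensorOp,           // opcode class
  cutlass::arch::Sm75,                      // architecture tag
  1,                                        // PartitionsK
  OutputOp,                                 // epilogue output operator
  4                                         // ElementsPerAccess
>::Epilogue;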

View File

@ -23,7 +23,7 @@
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
\brief Epilogue for threadblock scoped GEMMs using WMMA.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.

View File

@ -146,6 +146,54 @@ struct DefaultInterleavedThreadMapTensorOp {
////////////////////////////////////////////////////////////////////////////////
/// Defines the optimal thread map for TensorOp accumulator layouts
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
typename Element_, int ElementsPerAccess, int InterleavedK>
struct DefaultInterleavedConvThreadMapTensorOp {
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
static int const kPartitionsK = PartitionsK;
using Element = Element_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kInterleavedK = InterleavedK;
//
// Definitions
//
struct Detail {
/// Tensor Operations fundamentally perform operations on 8 rows
static int const kTensorOpRows = 8;
static int const kWarpSize = 32;
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kN % WarpShape::kN),
"Divisibility");
/// Number of warps
using WarpCount =
gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;
/// Number of participating threads
static int const kThreads = WarpCount::kCount * kWarpSize;
};
//
// ThreadMap
//
/// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
/// InterleavedOutputTileThreadMap
using Type = InterleavedConvOutputTileThreadMap<
MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
WarpShape::kN / InterleavedK>,
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
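For a concrete sense of the Detail computation in DefaultInterleavedConvThreadMapTensorOp above, under assumed shapes (not taken from this file):

// Assumed: ThreadblockShape = 128x128, WarpShape = 64x64, PartitionsK = 1
//   WarpCount = gemm::GemmShape<128/64, 128/64, 1> = GemmShape<2, 2, 1>   (4 warps)
//   kThreads  = WarpCount::kCount * kWarpSize = 4 * 32 = 128 participating threads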

View File

@ -32,7 +32,11 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@ -74,7 +78,7 @@ template <
class Epilogue :
public EpilogueBase<
Shape_,
WarpMmaOperator_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
@ -84,7 +88,7 @@ public:
using Base = EpilogueBase<
Shape_,
WarpMmaOperator_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
@ -171,8 +175,9 @@ public:
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
int64_t imag_stride_dest = 0, ///< Arguments required for planar complex case - not used in real-valued case
int64_t imag_stride_src = 0) { ///<
typename OutputTileIterator::Fragment source_fragment;

View File

@ -32,7 +32,11 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/matrix_shape.h"
@ -58,7 +62,7 @@ namespace threadblock {
/// Base class for epilogues defining warp-level
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
typename WarpShape_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
@ -68,7 +72,7 @@ class EpilogueBase {
public:
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
using WarpShape = WarpShape_;
static int const kPartitionsK = PartitionsK;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
@ -83,11 +87,10 @@ public:
/// Accumulator element
using ElementAccumulator = typename AccumulatorTile::Element;
/// Number of warps
using WarpCount = gemm::GemmShape<
Shape::kM / WarpMmaOperator::Shape::kM,
Shape::kN / WarpMmaOperator::Shape::kN,
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
kPartitionsK
>;
@ -144,24 +147,6 @@ public:
storage.data(),
Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
}
CUTLASS_DEVICE
void debug_print() {
if (threadIdx.x == 0) {
#pragma unroll 1
for (int r = 0; r < Shape::kRow; ++r) {
#pragma unroll 1
for (int c = 0; c < Shape::kColumn; ++c) {
printf("%d ", int(storage.data()[r * StorageShape::kColumn + c]));
}
printf("\n");
}
}
__syncthreads();
}
};
protected:

View File

@ -0,0 +1,397 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include <assert.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator for planar-complex output representations.
///
/// Note, as with most CUTLASS components for planar complex, the template arguments describe
/// the underlying real data type.
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
typename OutputOp_, ///< Output operator
typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
>
class EpiloguePlanarComplex {
public:
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
/// Output layout is always row-major
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = ArrayPlanarComplex<
typename WarpMmaOperator::FragmentC::Element,
WarpMmaOperator::FragmentC::kElements
>;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Shape of each warp-level operation
using WarpShape = typename WarpMmaOperator::Shape;
/// Number of warps
using WarpCount = gemm::GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
kPartitionsK
>;
/// Shared memory allocation
struct SharedStorage {
//
// Type definitions
//
/// Element type of shared memory
using Element = typename WarpTileIterator::Element;
/// Tensor reference to shared memory allocation
using TensorRef = typename WarpTileIterator::TensorRef;
/// Layout of shared memory allocation
using Layout = typename WarpTileIterator::Layout;
/// Logical shape of the shared memory tile written to by all warps.
using Shape = MatrixShape<
WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
WarpCount::kN * WarpTileIterator::Shape::kColumn
>;
/// Shape of the shared memory allocation for the epilogue
using StorageShape = MatrixShape<
Shape::kRow + Padding::kRow,
Shape::kColumn + Padding::kColumn
>;
static int const kImaginaryStride = StorageShape::kCount;
//
// Data members
//
AlignedBuffer<Element, kImaginaryStride * 2> storage;
//
// Methods
//
/// Returns a pointer to the shared memory buffer
CUTLASS_DEVICE
Element *data() {
return storage.data();
}
/// Returns a tensor reference to the shared memory buffer
CUTLASS_DEVICE
TensorRef reference() {
return TensorRef(
storage.data(),
Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
}
};
private:
//
// Data members
//
SharedStorage &shared_storage_;
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Stores a warp's fragment of accumulators to SMEM
WarpTileIterator warp_tile_iterator_;
public:
/// Constructor
CUTLASS_DEVICE
EpiloguePlanarComplex(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< ID of thread within warp
):
shared_storage_(shared_storage),
shared_load_iterator_(shared_storage.reference(), thread_idx),
warp_tile_iterator_(shared_storage.reference(), lane_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
//
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
int warp_m = warp_mn % WarpCount::kM;
int warp_n = warp_mn / WarpCount::kM;
MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
warp_tile_iterator_.add_tile_offset(warp_offset);
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator_real, ///< Tile iterator for destination (real part)
OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination (imaginary part)
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator_real, ///< Tile iterator for source tensor (real part)
OutputTileIterator source_iterator_imag) { ///< Tile iterator for source tensor (imaginary part)
typename OutputTileIterator::Fragment source_fragment_real;
typename OutputTileIterator::Fragment source_fragment_imag;
if (!output_op.is_source_needed()) {
source_iterator_real.clear_mask();
source_iterator_imag.clear_mask();
}
source_fragment_real.clear();
source_fragment_imag.clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real);
AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag);
//
// Iterate over accumulator tile
//
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
source_iterator_real.load(source_fragment_real);
source_iterator_imag.load(source_fragment_imag);
++source_iterator_real;
++source_iterator_imag;
//
// Convert and store fragment
//
__syncthreads();
typename AccumulatorFragmentIterator::Fragment accum_fragment_real;
typename AccumulatorFragmentIterator::Fragment accum_fragment_imag;
accum_fragment_iterator_real.load(accum_fragment_real);
accum_fragment_iterator_imag.load(accum_fragment_imag);
++accum_fragment_iterator_real;
++accum_fragment_iterator_imag;
this->warp_tile_iterator_.store(accum_fragment_real);
this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK];
typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment_real[0]);
shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride);
// A reduction amongst k-slices would be performed here if kPartitionsK > 1; only a single k-slice is supported at this time
static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time");
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment_real;
typename OutputTileIterator::Fragment output_fragment_imag;
apply_output_operator_(
output_fragment_real,
output_fragment_imag,
output_op,
aligned_accum_fragment_real[0],
aligned_accum_fragment_imag[0],
source_fragment_real,
source_fragment_imag);
//
// Store the final result
//
destination_iterator_real.store(output_fragment_real);
destination_iterator_imag.store(output_fragment_imag);
++destination_iterator_real;
++destination_iterator_imag;
}
}
private:
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
typename OutputTileIterator::Fragment &output_fragment_real,
typename OutputTileIterator::Fragment &output_fragment_imag,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag,
typename OutputTileIterator::Fragment const &source_fragment_real,
typename OutputTileIterator::Fragment const &source_fragment_imag) {
OutputAccessType *output_frag_real_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment_real);
OutputAccessType *output_frag_imag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment_imag);
AccumulatorAccessType const *compute_frag_real_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_real);
AccumulatorAccessType const *compute_frag_imag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment_imag);
OutputAccessType const *source_frag_real_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment_real);
OutputAccessType const *source_frag_imag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment_imag);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
auto result_fragment = output_op(
make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]),
make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i])
);
output_frag_real_ptr[i] = result_fragment.real;
output_frag_imag_ptr[i] = result_fragment.imag;
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
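A sketch of how a kernel might invoke an instantiated EpiloguePlanarComplex after the mainloop; the iterator and output-operator objects named below are assumptions standing in for kernel-specific types:

// __shared__ typename EpiloguePlanarComplex::SharedStorage shared_storage;
//
// EpiloguePlanarComplex epilogue(shared_storage, thread_idx, warp_idx, lane_idx);
//
// epilogue(
//   output_op,         // e.g. a LinearCombinationPlanarComplex functor
//   iterator_D_real,   // OutputTileIterator over the real part of D
//   iterator_D_imag,   // OutputTileIterator over the imaginary part of D
//   accumulators,      // ArrayPlanarComplex accumulator tile produced by the mainloop
//   iterator_C_real,   // OutputTileIterator over the real part of C
//   iterator_C_imag);  // OutputTileIterator over the imaginary part of C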

View File

@ -85,9 +85,6 @@ class InterleavedEpilogue {
using OutputTileIterator = OutputTileIterator_;
using OutputOp = OutputOp_;
/// Output layout is always row-major
using Layout = layout::ColumnMajorInterleaved<InterleavedK>;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

View File

@ -437,11 +437,10 @@ struct OutputTileOptimalThreadMap {
/// - minimal address arithmetic
/// - minimal predicate calculations
///
template <typename WarpCount_, typename MmaCount_, int Threads,
template <typename WarpCount_, typename Iterations_, int Threads,
int ElementsPerAccess, int ElementSize>
struct InterleavedOutputTileThreadMap {
using WarpCount = WarpCount_;
using MmaCount = MmaCount_;
static int const kWarpSize = 32;
static int const kThreads = Threads;
@ -460,7 +459,7 @@ struct InterleavedOutputTileThreadMap {
// Output
//
using Iterations = MmaCount;
using Iterations = Iterations_;
using Delta = layout::PitchLinearShape<kWarpSize * kElementsPerAccess, 1>;
@ -491,6 +490,67 @@ struct InterleavedOutputTileThreadMap {
////////////////////////////////////////////////////////////////////////////////
/// Template metaprogram for partitioning a 4D interleaved layout across warps
/// to achieve several performance objectives:
///
/// - coalesced memory accesses in units of 64 Byte lines
/// - minimal address arithmetic
/// - minimal predicate calculations
///
template <typename WarpCount_, typename Iterations_, int Threads,
int ElementsPerAccess, int ElementSize>
struct InterleavedConvOutputTileThreadMap {
using WarpCount = WarpCount_;
static int const kWarpSize = 32;
static int const kThreads = Threads;
static int const kWarpCount = kThreads / kWarpSize;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kElementSize = ElementSize;
//
// Metaprogram computation
//
struct Detail {};
//
// Output
//
using Iterations = Iterations_;
using Delta = MatrixShape<kWarpSize / 4, 4 * kElementsPerAccess>;
/// Initial offset function
CUTLASS_HOST_DEVICE
static MatrixCoord initial_offset(int thread_idx) {
int warp_idx = thread_idx / kWarpSize;
int lane_idx = thread_idx % kWarpSize;
// Compute warp location
MatrixCoord warp_footprint{
Delta::kRow * Iterations::kRow,
Delta::kColumn * Iterations::kColumn,
};
MatrixCoord warp_offset{warp_idx % WarpCount::kRow,
warp_idx / WarpCount::kRow};
// Compute per-lane offset
MatrixCoord thread_offset_in_warp{lane_idx / 4,
(lane_idx % 4) * kElementsPerAccess};
MatrixCoord thread_offset_in_threadblock_tile =
warp_footprint * warp_offset + thread_offset_in_warp;
return thread_offset_in_threadblock_tile;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
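Worked example of InterleavedConvOutputTileThreadMap::initial_offset above under assumed template parameters (Iterations = MatrixShape<4, 2>, WarpCount::kRow = 2, kElementsPerAccess = 4; none of these values come from this file):

// Delta = MatrixShape<32 / 4, 4 * 4> = MatrixShape<8, 16>
// warp_footprint        = {8 * 4, 16 * 2}          = {32, 32}
// thread_idx = 37  =>  warp_idx = 1, lane_idx = 5
// warp_offset           = {1 % 2, 1 / 2}            = {1, 0}
// thread_offset_in_warp = {5 / 4, (5 % 4) * 4}      = {1, 4}
// initial_offset(37)    = {32 * 1 + 1, 32 * 0 + 4}  = {33, 4}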

View File

@ -36,9 +36,9 @@
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
@ -107,16 +107,16 @@ public:
// Data members
//
Index stride; ///< stride in bytes between rows
LongIndex stride; ///< stride in bytes between rows
Index increment_row; ///< increment quantity (in bytes) to advance when moving between rows
Index increment_group; ///< increment quantity (in bytes) to advance when moving to the next group
Index increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster
LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows
LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group
LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster
Index advance_row; ///< amount to add to move to the next 'row' position
Index advance_group; ///< amount to add to move to the next 'group' position
Index advance_cluster; ///< amount to add to move to the next 'cluster' position
Index advance_tile; ///< amount to add to move to the next 'tile'
LongIndex advance_row; ///< amount to add to move to the next 'row' position
LongIndex advance_group; ///< amount to add to move to the next 'group' position
LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position
LongIndex advance_tile; ///< amount to add to move to the next 'tile'
//
// Methods
@ -125,7 +125,7 @@ public:
CUTLASS_HOST_DEVICE
Status initialize(Index stride_) {
stride = stride_;
stride = LongIndex(stride_);
increment_row = stride * ThreadMap::Delta::kRow;
@ -261,8 +261,8 @@ public:
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
thread_offset.row() * params_.stride +
thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
@ -276,7 +276,7 @@ public:
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) {
uint8_t *byte_pointer = byte_pointer_;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
@ -299,7 +299,7 @@ public:
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
@ -328,9 +328,15 @@ public:
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
load_with_byte_offset(frag, 0);
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) {
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) {
uint8_t *byte_pointer = byte_pointer_;
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
@ -352,7 +358,7 @@ public:
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
@ -382,6 +388,12 @@ public:
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) {
store_with_byte_offset(frag, 0);
}
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIterator &operator++() {
@ -440,6 +452,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator
@ -447,7 +460,7 @@ public:
template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
int InterleavedK ///< Number of Interleaved K
int InterleavedN ///< Number of Interleaved N
>
class InterleavedPredicatedTileIterator {
public:
@ -455,7 +468,7 @@ public:
using Element = Element_;
using Layout = layout::ColumnMajorInterleaved<InterleavedK>;
using Layout = layout::ColumnMajorInterleaved<InterleavedN>;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
@ -483,10 +496,10 @@ public:
// Data members
//
Index stride; ///< stride in bytes between columns
LongIndex stride; ///< stride in bytes between columns
Index advance_row; ///< amount to add to move to the next 'row' position
Index advance_column; ///< amount to add to move to the next 'column' position
LongIndex advance_row; ///< amount to add to move to the next 'row' position
LongIndex advance_column; ///< amount to add to move to the next 'column' position
//
// Methods
@ -494,14 +507,16 @@ public:
CUTLASS_HOST_DEVICE
Status initialize(Index stride_) {
stride = stride_;
stride = LongIndex(stride_);
advance_row =
ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8;
advance_column =
stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *
sizeof_bits<Element>::value * ThreadMap::kWarpSize / 8;
advance_column = LongIndex(stride_) - ThreadMap::Iterations::kContiguous *
kElementsPerAccess *
sizeof_bits<Element>::value *
ThreadMap::kWarpSize / 8;
return Status::kSuccess;
}
@ -602,10 +617,10 @@ public:
):
params_(params) {
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +
TensorCoord(threadblock_offset.contiguous() * InterleavedK,
threadblock_offset.strided() / InterleavedK);
TensorCoord(threadblock_offset.contiguous() * InterleavedN,
threadblock_offset.strided() / InterleavedN);
extent_col_ = extent.strided() / InterleavedK;
extent_col_ = extent.strided() / InterleavedN;
thread_start_col_ = thread_offset.strided();
// Initialize predicates
@ -613,13 +628,13 @@ public:
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
mask_.predicates[c] =
((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) <
(extent.contiguous() * InterleavedK));
(extent.contiguous() * InterleavedN));
}
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
thread_offset.strided() * params_.stride +
thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess;
LongIndex(thread_offset.strided()) * LongIndex(params_.stride) +
LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess;
// Initialize internal state counter
iteration_contiguous_ = iteration_strided_ = 0;
@ -634,6 +649,7 @@ public:
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
uint8_t *byte_pointer = byte_pointer_;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);

View File

@ -100,12 +100,28 @@ public:
/// Padding quantity
using Padding = MatrixShape<
0,
4 * Policy::kElementsPerAccess>;
4 * Policy::kElementsPerAccess
#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES
+ 1
#endif
>;
private:
#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES
/// Storage type for accessing memory
using AccessType = AlignedArray<Element, Policy::kElementsPerAccess>;
using AccessType = AlignedArray<
Element,
1
>;
#else
/// Storage type for accessing memory
using AccessType = AlignedArray<
Element,
Policy::kElementsPerAccess
>;
#endif
//
// Data members
@ -130,18 +146,21 @@ public:
unsigned lane_id
):
pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0] / Policy::kElementsPerAccess) {
layout_(ref.stride()[0] / AccessType::kElements) {
auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id);
pointer_ += layout_(lane_offset);
pointer_ += layout_({
lane_offset.row(),
lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements)
});
}
/// Adds a pointer offset
CUTLASS_HOST_DEVICE
TileIteratorSimt & add_pointer_offset(Index pointer_offset) {
pointer_ += pointer_offset / Policy::kElementsPerAccess;
pointer_ += pointer_offset / AccessType::kElements;
return *this;
}
@ -151,7 +170,7 @@ public:
pointer_ += layout_({
tile_offset.row() * Shape::kRow,
(tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess)
(tile_offset.column() * Shape::kColumn / int(AccessType::kElements))
});
return *this;
@ -173,7 +192,7 @@ public:
// de-vectorized stores
using ScalarAccessType = AlignedArray<Element, 1>;
ScalarAccessType const *scalarFragPtr = reinterpret_cast<ScalarAccessType const *>(&frag);
ScalarAccessType *scalarPointer = reinterpret_cast<ScalarAccessType *>(pointer_);
ScalarAccessType *scalarPointer = reinterpret_cast<ScalarAccessType *>(pointer_) + pointer_offset;
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
@ -187,7 +206,7 @@ public:
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn] = frag_ptr[n];
pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n];
}
#endif
}
@ -206,7 +225,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn];
frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)];
}
}

View File

@ -187,7 +187,8 @@ public:
int access = access_idx % 2;
int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess +
access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + access;
access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess +
access + pointer_offset / Policy::kElementsPerAccess;
int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx;
@ -219,7 +220,9 @@ public:
int access_quad = access_idx / 2;
int access = access_idx % 2;
int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + access;
int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta +
access + pointer_offset / Policy::kElementsPerAccess;
int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx;
frag_ptr[frag_idx] = pointer_[ptr_offset];
@ -382,7 +385,7 @@ public:
int ptr_row_offset = row_idx * 2;
int ptr_offset = layout_({ptr_row_offset, ptr_column_offset});
int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess;
pointer_[ptr_offset] = frag_ptr[frag_idx];
}

View File

@ -25,7 +25,12 @@
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include "cutlass/cutlass.h"
/**
@ -201,6 +206,18 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul,
rem = src - (quo * div);
}
/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b
CUTLASS_HOST_DEVICE
int round_up(int a, int b) {
return ((a + b - 1) / b) * b;
}
/// Returns the ceiling of (a / b)
CUTLASS_HOST_DEVICE
int ceil_div(int a, int b) {
return (a + b - 1) / b;
}
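// Illustrative values (assumed, not from this file): for a = 10, b = 4
//   ceil_div(10, 4)  == 3    // ceiling of 10 / 4
//   round_up(10, 4)  == 12   // smallest multiple of 4 that is >= 10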
/******************************************************************************
* Min/Max
******************************************************************************/

View File

@ -1,4 +1,4 @@
/***************************************************************************************************
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
@ -1207,4 +1207,8 @@ struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -192,7 +192,8 @@ template <
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
typename ThreadblockSwizzle_ =
typename threadblock::GemmCohortThreadblockSwizzle<LayoutA_, LayoutB_>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
@ -241,6 +242,8 @@ class Gemm {
static int const kAlignmentC = EpilogueOutputOp::kCount;
static bool const kSplitKSerial = SplitKSerial;
static bool const kIsBetaZero = IsBetaZero;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Define the kernel
using GemmKernel = typename kernel::DefaultGemm<
@ -348,21 +351,23 @@ public:
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return 0;
return bytes;
}
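// For illustration (values assumed): a 1024 x 1024 x 4096 problem with 128 x 128 threadblock
// tiles and split_k_slices = 4 gives tiled_shape = (8, 8, 4), so the serial split-K workspace
// holds one int lock per output tile: 8 * 8 * sizeof(int) = 256 bytes.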
/// Initializes GEMM state from arguments.
@ -426,6 +431,7 @@ public:
params_.ref_B.reset(args.ref_B.non_const_ref().data());
params_.ref_C.reset(args.ref_C.non_const_ref().data());
params_.ref_D.reset(args.ref_D.data());
params_.output_op = args.epilogue;
params_.semaphore = static_cast<int *>(workspace);
return Status::kSuccess;
@ -560,6 +566,8 @@ class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static bool const kSplitKSerial = SplitKSerial;
static bool const kIsBetaZero = IsBetaZero;

View File

@ -0,0 +1,745 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined batched array GEMM in which operands are accessed through arrays of pointers. Does not support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_array.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
////////////////////////////////////////////////////////////////////////////////
/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may
be invoked from host code.
The contributions of this class are:
1. At compile time, it maps data types and high-level structural parameters onto
specific CUTLASS components.
2. At runtime, it maps logical arguments to GEMM problems to kernel parameters.
3. At runtime, it launches kernels on the device.
The intent is to provide a convenient mechanism for interacting with most plausible GEMM
configurations for each supported architecture. Consequently, not all parameters are exposed
to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy
are selected to tradeoff simplicity of the interface with flexibility. We expect
most configurations to be specified at this level. Applications with more exotic requirements
may construct their kernels of interest using CUTLASS components at the threadblock, warp,
and thread levels of abstraction.
CUTLASS exposes computations using the functor design pattern in which objects compose some
internal state with an overloaded function call operator. This enables decoupling of
initialization from execution, possibly reducing overhead during steady state phases of
application execution.
CUTLASS device-level operators expose an Arguments structure encompassing each logical
input to the computation. This is distinct from the kernel-level Params structure pattern
which contains application-specific precomputed state needed by the device code.
Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN
is as follows:
//
// Instantiate the CUTLASS GEMM operator.
//
cutlass::gemm::device::Gemm<
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor
> gemm_op;
//
// Launch the GEMM operation on the device
//
cutlass::Status status = gemm_op({
{m, n, k}, // GemmCoord problem_size,
{A, lda}, // TensorRef<float, layout::ColumnMajor> ref_A,
{B, ldb}, // TensorRef<float, layout::ColumnMajor> ref_B,
{C, ldc}, // TensorRef<float, layout::ColumnMajor> ref_C,
{D, ldd}, // TensorRef<float, layout::ColumnMajor> ref_D,
{alpha, beta} // EpilogueOutputOp::Params epilogue_op_params
});
A simplified view of the template is listed below.
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages
>
class Gemm;
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator
>
class GemmArray {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
using Operator = Operator_;
/// Define the kernel
using DefaultGemmKernel = typename kernel::DefaultGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
false,
Operator,
false
>::GemmKernel;
using GemmKernel = kernel::GemmArray<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle>;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
ElementA const * const *ptr_A;
LayoutA layout_A;
ElementB const * const *ptr_B;
LayoutB layout_B;
ElementC const * const *ptr_C;
LayoutC layout_C;
ElementC * const * ptr_D;
LayoutC layout_D;
typename EpilogueOutputOp::Params epilogue;
int batch_count;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() { }
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
ElementA const * const *ptr_A_,
LayoutA layout_A_,
ElementB const * const *ptr_B_,
LayoutB layout_B_,
ElementC const * const *ptr_C_,
LayoutC layout_C_,
ElementC * const * ptr_D_,
LayoutC layout_D_,
typename EpilogueOutputOp::Params epilogue_,
int batch_count_
):
problem_size(problem_size_),
ptr_A(ptr_A_),
layout_A(layout_A_),
ptr_B(ptr_B_),
layout_B(layout_B_),
ptr_C(ptr_C_),
layout_C(layout_C_),
ptr_D(ptr_D_),
layout_D(layout_D_),
epilogue(epilogue_),
batch_count(batch_count_) { }
};
private:
/// Kernel parameters object
typename GemmKernel::Params params_;
public:
/// Constructs the GEMM.
GemmArray() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (args.layout_A.stride(0) % kAlignmentA) {
return Status::kErrorMisalignedOperand;
}
if (args.layout_B.stride(0) % kAlignmentB) {
return Status::kErrorMisalignedOperand;
}
if (args.layout_C.stride(0) % kAlignmentC) {
return Status::kErrorMisalignedOperand;
}
if (args.layout_D.stride(0) % kAlignmentC) {
return Status::kErrorMisalignedOperand;
}
if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
(args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
(args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return 0;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
args.batch_count,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
// Initialize the Params structure
params_ = typename GemmKernel::Params{
args.problem_size,
grid_shape,
args.ptr_A,
args.layout_A,
args.ptr_B,
args.layout_B,
args.ptr_C,
args.layout_C,
args.ptr_D,
args.layout_D,
args.epilogue,
args.batch_count
};
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
args.batch_count,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
params_ = typename GemmKernel::Params{
args.problem_size,
grid_shape,
args.ptr_A,
args.layout_A,
args.ptr_B,
args.layout_B,
args.ptr_C,
args.layout_C,
args.ptr_D,
args.layout_D,
args.epilogue,
args.batch_count
};
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
cudaError_t result;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
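/*
  Usage sketch (illustrative only; names such as m, n, k, lda, ptr_A_array, alpha, and beta are
  placeholders, and the SGEMM configuration is an assumption chosen for brevity). GemmArray reads
  each batch's operands through device-resident arrays of pointers:

    using BatchedGemm = cutlass::gemm::device::GemmArray<
      float, cutlass::layout::ColumnMajor,     // A
      float, cutlass::layout::ColumnMajor,     // B
      float, cutlass::layout::ColumnMajor>;    // C and D

    BatchedGemm gemm_op;

    BatchedGemm::Arguments args(
      {m, n, k},                  // problem size shared by every batch index
      ptr_A_array, lda,           // batch_count device pointers to A and A's leading dimension
      ptr_B_array, ldb,
      ptr_C_array, ldc,
      ptr_D_array, ldd,
      {alpha, beta},              // linear combination epilogue parameters
      batch_count);

    cutlass::Status status = gemm_op(args);    // initialize() followed by run()
*/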
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for column-major output; exchanges problem size and operands.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
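/// Operation performed by GEMM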
typename Operator_
>
class GemmArray<
ElementA_,
LayoutA_,
ElementB_,
LayoutB_,
ElementC_,
layout::ColumnMajor,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
AlignmentA,
AlignmentB,
Operator_
> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static bool const kSplitKSerial = false;
//
using UnderlyingOperator = GemmArray<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA
>;
using UnderlyingArguments = typename UnderlyingOperator::Arguments;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
ElementA const * const *ptr_A;
LayoutA layout_A;
ElementB const * const *ptr_B;
LayoutB layout_B;
ElementC const * const *ptr_C;
LayoutC layout_C;
ElementC * const * ptr_D;
LayoutC layout_D;
typename EpilogueOutputOp::Params epilogue;
int batch_count;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() { }
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
ElementA const * const *ptr_A_,
LayoutA layout_A_,
ElementB const * const *ptr_B_,
LayoutB layout_B_,
ElementC const * const *ptr_C_,
LayoutC layout_C_,
ElementC * const * ptr_D_,
LayoutC layout_D_,
typename EpilogueOutputOp::Params epilogue_,
int batch_count_
):
problem_size(problem_size_),
ptr_A(ptr_A_),
layout_A(layout_A_),
ptr_B(ptr_B_),
layout_B(layout_B_),
ptr_C(ptr_C_),
layout_C(layout_C_),
ptr_D(ptr_D_),
layout_D(layout_D_),
epilogue(epilogue_),
batch_count(batch_count_) { }
};
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmArray() { }
/// Helper to construct a transposed equivalent for the underlying GEMM operator
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
GemmCoord problem_size{
args.problem_size.n(),
args.problem_size.m(),
args.problem_size.k()
};
return UnderlyingArguments(
problem_size,
args.ptr_B,
args.layout_B.stride(),
args.ptr_A,
args.layout_A.stride(),
args.ptr_C,
args.layout_C.stride(),
args.ptr_D,
args.layout_D.stride(),
args.epilogue,
args.batch_count
);
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
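/*
  Why the operand exchange above is valid: a column-major M-by-N matrix C occupies the same
  storage as a row-major N-by-M matrix holding C^T, and C = A * B implies C^T = B^T * A^T. The
  partial specialization therefore maps the column-major problem onto the row-major kernel by
  swapping the A and B pointer arrays, transposing their layouts, and exchanging M and N, as
  implemented in to_underlying_arguments().
*/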
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -370,8 +370,8 @@ public:
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
args.batch_count,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.batch_count);
// Initialize the Params structure
params_ = typename GemmKernel::Params{

View File

@ -201,6 +201,8 @@ template <
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone,
/// Multiply-add operator
typename Operator_ = arch::OpMultiplyAddComplex,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false
>
@ -228,7 +230,11 @@ class GemmComplex {
static int const kStages = Stages;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Operator = Operator_;
static bool const kSplitKSerial = SplitKSerial;
static int const kAlignmentA = 1;
static int const kAlignmentB = 1;
static int const kAlignmentC = EpilogueOutputOp::kCount;
/// Define the kernel
using GemmKernel = typename kernel::DefaultGemmComplex<
@ -249,6 +255,7 @@ class GemmComplex {
kStages,
kTransformA,
kTransformB,
Operator,
kSplitKSerial
>::GemmKernel;
@ -498,6 +505,8 @@ template <
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Multiply-add operator
typename Operator_,
/// If true, kernel supports split-K as a serial reduction
bool SplitKSerial
>
@ -519,6 +528,7 @@ class GemmComplex<
Stages,
TransformA,
TransformB,
Operator_,
SplitKSerial
> {
public:
@ -542,6 +552,7 @@ public:
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static int const kStages = Stages;
using Operator = Operator_;
static bool const kSplitKSerial = SplitKSerial;
using UnderlyingOperator = GemmComplex<
@ -562,8 +573,15 @@ public:
Stages,
TransformA,
TransformB,
Operator,
SplitKSerial
>;
static int const kAlignmentA = UnderlyingOperator::kAlignmentB;
static int const kAlignmentB = UnderlyingOperator::kAlignmentA;
static int const kAlignmentC = UnderlyingOperator::kAlignmentC;
static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB;
static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA;
using UnderlyingArguments = typename UnderlyingOperator::Arguments;
using GemmKernel = typename UnderlyingOperator::GemmKernel;

View File

@ -0,0 +1,372 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
  batched array variants.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*!
The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversal :
GemmUniversalBase<
typename kernel::DefaultGemmUniversal<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
> {
public:
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Base = GemmUniversalBase<
typename kernel::DefaultGemmUniversal<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
>;
using Arguments = typename Base::Arguments;
using GemmKernel = typename Base::GemmKernel;
};
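/*
  Usage sketch (illustrative only): a single-precision universal GEMM launched in the default
  kGemm mode. The Arguments structure is defined by kernel::GemmUniversal; the field order shown
  here is indicative rather than authoritative, and m, n, k, the pointers, strides, and scalars
  are placeholders:

    using Gemm = cutlass::gemm::device::GemmUniversal<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;

    Gemm gemm_op;

    Gemm::Arguments args(
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k},
      1,                                  // batch count (or split-K slices in kGemm mode)
      {alpha, beta},
      ptr_A, ptr_B, ptr_C, ptr_D,
      batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D,
      lda, ldb, ldc, ldd);

    size_t workspace_bytes = Gemm::get_workspace_size(args);
    // allocate workspace_bytes of device memory if nonzero, then:
    cutlass::Status status = gemm_op(args, workspace, stream);
*/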
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for column-major output; exchanges problem size and operands.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// Operation performed by GEMM
typename Operator_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB>
class GemmUniversal<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
layout::ColumnMajor, // partially specialized on LayoutC
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
WarpShape_, InstructionShape_, EpilogueOutputOp_,
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
Operator_, TransformA, TransformB> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using UnderlyingOperator = typename GemmUniversal<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA,
Operator,
kTransformB,
kTransformA
>::Base;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
static int const kAlignmentC = EpilogueOutputOp::kCount;
/// Argument structure
using Arguments = typename UnderlyingOperator::Arguments;
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmUniversal() { }
/// Helper to construct a transposed equivalent for the underlying GEMM operator
static Arguments to_underlying_arguments(Arguments const &args) {
return args.transposed_problem();
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,172 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmKernel_>
class GemmUniversalAdapter {
public:
using GemmKernel = GemmKernel_;
static_assert(std::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value,
"Universal adapter expects the kernel to be row-major and transposes its arguments.");
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using WarpShape = typename GemmKernel::WarpShape;
using InstructionShape = typename GemmKernel::InstructionShape;
using OperatorClass = typename GemmKernel::OperatorClass;
using ArchTag = typename GemmKernel::ArchTag;
// Type, layout, and complex transform deliberately exchanged with B
using ElementA = typename GemmKernel::ElementB;
using LayoutA = typename layout::LayoutTranspose<typename GemmKernel::LayoutB>::type;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformB;
static int const kAlignmentA = GemmKernel::kAlignmentA;
// Type, layout, and complex transform deliberately exchanged with A
using ElementB = typename GemmKernel::ElementA;
using LayoutB = typename layout::LayoutTranspose<typename GemmKernel::LayoutA>::type;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformA;
static int const kAlignmentB = GemmKernel::kAlignmentB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = cutlass::layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = GemmKernel::kAlignmentC;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
static int const kStages = GemmKernel::Mma::kStages;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
using Arguments = typename UnderlyingOperator::Arguments;
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmUniversalAdapter() { }
/// Helper to construct a transposed equivalent for the underlying GEMM operator
static Arguments to_underlying_arguments(Arguments const &args) {
return args.transposed_problem();
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
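/*
  Intended use (a sketch under the assumptions documented above): the adapter wraps an existing
  row-major GemmKernel - for example one produced by kernel::DefaultGemmUniversal - and exposes it
  with A and B exchanged so that callers observe a column-major C, which suits wrapping kernels
  generated for the CUTLASS Library:

    using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal<
      ... >::GemmKernel;                 // template arguments deliberately elided in this sketch

    using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

    DeviceGemm op;
    cutlass::Status status = op(args, workspace, stream);   // Arguments as in GemmUniversalBase
*/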
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,339 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmKernel_>
class GemmUniversalBase {
public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
protected:
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
}
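/*
  Worked example of the split-K fix-up above (illustrative values): with half-precision A and B,
  kAlignK = max(128 / 16, 128 / 16, 1) = 8. For problem_size.k() = 4096 and batch_count = 3 in
  kGemm mode, gemm_k_size = round_up(ceil_div(4096, 3), 8) = round_up(1366, 8) = 1368, so
  grid_tiled_shape.k() = ceil_div(4096, 1368) = 3 and each partition covers at most 1368
  K-aligned elements of the GEMM K dimension.
*/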
public:
/// Constructs the GEMM.
GemmUniversalBase() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
// Split-K parallel always requires a temporary workspace
workspace_bytes =
sizeof(ElementC) *
size_t(args.batch_stride_D) *
size_t(grid_tiled_shape.k());
}
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
return workspace_bytes;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
return threadblock_swizzle.get_grid_shape(grid_tiled_shape);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel<GemmKernel>,
GemmKernel::kThreadCount,
smem_size);
if (result == cudaSuccess) {
return max_active_blocks;
}
}
else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel<GemmKernel>,
GemmKernel::kThreadCount,
0);
if (result != cudaSuccess) {
return -1;
}
if (smem_capacity < 0) {
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
return std::min(max_active_blocks, smem_capacity / smem_size);
}
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm) {
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(
args,
grid_tiled_shape,
gemm_k_size,
static_cast<int *>(workspace)
);
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
params_.update(args, workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
cudaError_t result;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
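/*
  Note on run() above: if the kernel's SharedStorage exceeds the 48 KB static shared memory limit,
  the launch must opt in to larger dynamic shared memory. run() therefore calls
  cudaFuncSetAttribute() with cudaFuncAttributeMaxDynamicSharedMemorySize (and requests a full
  shared memory carveout) before launching; if either attribute cannot be set on the current
  device, Status::kErrorInternal is returned and the kernel is not launched.
*/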
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -394,6 +394,15 @@ struct BatchedGemmCoord : public Coord<4, int> {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
enum class GemmUniversalMode {
kGemm,
kGemmSplitKParallel,
kBatched,
kArray
};
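/// Indicative meaning of each mode as interpreted by device::GemmUniversalBase:
///   kGemm               - single GEMM; batch_count > 1 selects serial split-K reduction
///   kGemmSplitKParallel - split-K partials written to a workspace and reduced by a separate kernel
///   kBatched            - batched strided GEMM
///   kArray              - batched GEMM reading operand pointers from device-side arrays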
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm

View File

@ -0,0 +1,118 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Multiply-add operator
typename Operator,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial
>
struct DefaultGemmComplex;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,229 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/kernel/gemm_planar_complex.h"
#include "cutlass/gemm/kernel/gemm_planar_complex_array.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h"
#include "cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Math operation performed by GEMM (e.g. arch::OpMultiplyAdd)
typename Operator,
/// Conditional enabling to switch between stages
typename Enable = void
>
struct DefaultGemmPlanarComplexUniversal;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for pipelined mainloop
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator
>
struct DefaultGemmPlanarComplexUniversal<
ElementA,
LayoutA,
TransformA,
kAlignmentA,
ElementB,
LayoutB,
TransformB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator,
typename std::enable_if<(Stages <= 2)>::type
> {
/// Define planar complex valued variants instead
using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexPipelined<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Stages,
TransformA,
TransformB,
Operator
>::ThreadblockMma;
/// Planar complex epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex<
ThreadblockShape,
typename Mma::Policy::Operator,
OperatorClass,
ArchTag,
ThreadblockShape::kK / WarpShape::kK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::GemmPlanarComplex<
Mma,
Epilogue,
ThreadblockSwizzle
>;
// Array variant
using GemmArrayKernel = kernel::GemmPlanarComplexArray<
Mma,
Epilogue,
ThreadblockSwizzle
>;
};
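/*
  Planar complex background (reference arithmetic only, not the kernel's data path): each operand
  is stored as disjoint real and imaginary planes, so a complex multiply-accumulate is assembled
  from four real-valued products per element:

    acc_real += a_real * b_real - a_imag * b_imag;
    acc_imag += a_real * b_imag + a_imag * b_real;

  ComplexTransform::kConjugate on A or B negates the corresponding imaginary plane before these
  products are formed.
*/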
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,308 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Conditional enabling to select among partial specializations
typename Enable = void
>
struct DefaultGemmUniversal;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator>
struct DefaultGemmUniversal<
ElementA,
LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA,
ElementB,
LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator,
typename std::enable_if< ! cutlass::is_complex<ElementAccumulator>::value>::type
> {
using DefaultGemmKernel = typename kernel::DefaultGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
true,
Operator,
false
>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Complex-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator
>
struct DefaultGemmUniversal<
ElementA,
LayoutA,
TransformA,
kAlignmentA,
ElementB,
LayoutB,
TransformB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator,
typename std::enable_if<cutlass::is_complex<ElementAccumulator>::value>::type
> {
using DefaultGemmKernel = typename kernel::DefaultGemmComplex<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
TransformA,
TransformB,
Operator,
false
>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
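The two DefaultGemmUniversal partial specializations above are selected purely by the trailing
std::enable_if clause on the accumulator type: real accumulators match the first, complex
accumulators the second. Below is a minimal stand-alone sketch of the same SFINAE dispatch,
assuming simplified stand-ins (std::complex and a hypothetical SelectKernel) rather than the
actual CUTLASS templates.

```
#include <complex>
#include <iostream>
#include <type_traits>

template <typename T> struct is_complex : std::false_type {};
template <typename T> struct is_complex<std::complex<T> > : std::true_type {};

// Primary template; one of the partial specializations below is chosen by the enable_if clause.
template <typename ElementAccumulator, typename Enable = void>
struct SelectKernel;

// Selected when the accumulator is a real type.
template <typename ElementAccumulator>
struct SelectKernel<
  ElementAccumulator,
  typename std::enable_if<!is_complex<ElementAccumulator>::value>::type> {

  static char const *name() { return "real-valued GEMM kernel"; }
};

// Selected when the accumulator is a complex type.
template <typename ElementAccumulator>
struct SelectKernel<
  ElementAccumulator,
  typename std::enable_if<is_complex<ElementAccumulator>::value>::type> {

  static char const *name() { return "complex-valued GEMM kernel"; }
};

int main() {
  std::cout << SelectKernel<float>::name() << std::endl;                 // real-valued GEMM kernel
  std::cout << SelectKernel<std::complex<float> >::name() << std::endl;  // complex-valued GEMM kernel
  return 0;
}
```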

0
include/cutlass/gemm/kernel/default_gemv.h Normal file → Executable file
@ -83,7 +83,7 @@ struct Gemm {
//
CUTLASS_HOST_DEVICE
Params() { }
Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
CUTLASS_HOST_DEVICE
Params(
@ -94,7 +94,7 @@ struct Gemm {
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
typename OutputOp::Params output_op = typename OutputOp::Params(),
int *semaphore = nullptr
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
@ -106,13 +106,14 @@ struct Gemm {
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
output_op(output_op),
semaphore(semaphore) {
output_op(output_op) {
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
semaphore = workspace;
}
};
@ -220,7 +221,9 @@ struct Gemm {
thread_idx,
tb_offset_B);
int warp_idx = threadIdx.x / 32;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//

@ -0,0 +1,253 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel operating on arrays of matrix pointers, one pointer per batch. Does not support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmArray {
using Mma = Mma_;
using Epilogue = Epilogue_;
using OutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::Element const * const * ptr_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::Element const * const * ptr_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Element const * const * ptr_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::Element * const * ptr_D;
int64_t stride_D;
typename OutputOp::Params epilogue;
int batch_count;
int gemm_k_iterations;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size_,
cutlass::gemm::GemmCoord const & grid_tiled_shape_,
typename Mma::IteratorA::Element const * const * ptr_A_,
typename Mma::IteratorA::Layout layout_A,
typename Mma::IteratorB::Element const * const * ptr_B_,
typename Mma::IteratorB::Layout layout_B,
typename Epilogue::OutputTileIterator::Element const * const * ptr_C_,
typename Epilogue::OutputTileIterator::Layout layout_C,
typename Epilogue::OutputTileIterator::Element * const * ptr_D_,
typename Epilogue::OutputTileIterator::Layout layout_D,
typename OutputOp::Params epilogue_,
int batch_count_
):
problem_size(problem_size_),
grid_tiled_shape(grid_tiled_shape_),
params_A(layout_A),
ptr_A(ptr_A_),
params_B(layout_B),
ptr_B(ptr_B_),
params_C(layout_C),
ptr_C(ptr_C_),
params_D(layout_D),
ptr_D(ptr_D_),
epilogue(epilogue_),
batch_count(batch_count_),
gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) {
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
GemmArray() { }
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension
for (int batch_idx = threadblock_swizzle.get_batch_idx();
batch_idx < params.batch_count;
batch_idx += gridDim.z) {
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
0
};
cutlass::MatrixCoord tb_offset_B{
0,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
const_cast<typename Mma::IteratorA::Element *>(params.ptr_A[batch_idx]),
params.problem_size.mk(),
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
const_cast<typename Mma::IteratorB::Element *>(params.ptr_B[batch_idx]),
params.problem_size.kn(),
thread_idx,
tb_offset_B);
//
// Main loop
//
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Epilogue
//
OutputOp output_op(params.epilogue);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
// Tile iterator writing to output tile
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
const_cast<typename Epilogue::OutputTileIterator::Element *>(params.ptr_C[batch_idx]),
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to output tile
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ptr_D[batch_idx],
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// run efficient epilogue
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
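GemmArray consumes arrays of per-batch base pointers (Element const * const *): entry batch_idx
of each array addresses that batch's operand, and threadblocks stride over batches using
gridDim.z. A host-side sketch of building one such pointer array, assuming a hypothetical helper
name and plain float operands (this is not a CUTLASS API, and error checking is omitted):

```
#include <cuda_runtime.h>
#include <vector>

// Copy an array of per-batch device base pointers to the device so the resulting
// pointer-to-pointers can be passed to the kernel's Params (e.g. as ptr_A).
float const **make_device_pointer_array(std::vector<float const *> const &host_ptrs) {
  float const **device_array = nullptr;
  cudaMalloc((void **)&device_array, sizeof(float const *) * host_ptrs.size());
  cudaMemcpy(device_array, host_ptrs.data(),
             sizeof(float const *) * host_ptrs.size(), cudaMemcpyHostToDevice);
  return device_array;
}
```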

@ -193,8 +193,10 @@ struct GemmBatched {
// Main loop
//
// Construct thread-scoped matrix multiply
int warp_idx = threadIdx.x / 32;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

@ -0,0 +1,693 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Kernel computing a planar complex GEMM in which each operand is stored as disjoint real and imaginary parts. Supports conventional, batched, array, and split-K modes.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmPlanarComplex {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using Operator = typename Mma::Operator;
using ArchTag = typename Mma::ArchTag;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value);
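// For example, with half-precision A and B operands this evaluates to
// const_max(128 / 16, 128 / 16) = 8, so splits are taken at multiples of
// 8 elements (one 128-bit access).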
//
// Additional types needed for reflection
//
using ElementAccumulator = typename Mma::Policy::Operator::ElementC;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::Shape;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
//
// Arguments structure
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A_real;
void const * ptr_A_imag;
void const * ptr_B_real;
void const * ptr_B_imag;
void const * ptr_C_real;
void const * ptr_C_imag;
void * ptr_D_real;
void * ptr_D_imag;
int lda_real;
int lda_imag;
int ldb_real;
int ldb_imag;
int ldc_real;
int ldc_imag;
int ldd_real;
int ldd_imag;
int64_t batch_stride_A;
int64_t batch_stride_A_imag;
int64_t batch_stride_B;
int64_t batch_stride_B_imag;
int64_t batch_stride_C;
int64_t batch_stride_C_imag;
int64_t batch_stride_D;
int64_t batch_stride_D_imag;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr)
{ }
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A_real,
void const * ptr_A_imag,
void const * ptr_B_real,
void const * ptr_B_imag,
void const * ptr_C_real,
void const * ptr_C_imag,
void * ptr_D_real,
void * ptr_D_imag,
int lda_real,
int lda_imag,
int ldb_real,
int ldb_imag,
int ldc_real,
int ldc_imag,
int ldd_real,
int ldd_imag,
int64_t batch_stride_A = 0,
int64_t batch_stride_A_imag = 0,
int64_t batch_stride_B = 0,
int64_t batch_stride_B_imag = 0,
int64_t batch_stride_C = 0,
int64_t batch_stride_C_imag = 0,
int64_t batch_stride_D = 0,
int64_t batch_stride_D_imag = 0
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A_real(ptr_A_real),
ptr_A_imag(ptr_A_imag),
ptr_B_real(ptr_B_real),
ptr_B_imag(ptr_B_imag),
ptr_C_real(ptr_C_real),
ptr_C_imag(ptr_C_imag),
ptr_D_real(ptr_D_real),
ptr_D_imag(ptr_D_imag),
lda_real(lda_real),
lda_imag(lda_imag),
ldb_real(ldb_real),
ldb_imag(ldb_imag),
ldc_real(ldc_real),
ldc_imag(ldc_imag),
ldd_real(ldd_real),
ldd_imag(ldd_imag),
batch_stride_A(batch_stride_A),
batch_stride_A_imag(batch_stride_A_imag),
batch_stride_B(batch_stride_B),
batch_stride_B_imag(batch_stride_B_imag),
batch_stride_C(batch_stride_C),
batch_stride_C_imag(batch_stride_C_imag),
batch_stride_D(batch_stride_D),
batch_stride_D_imag(batch_stride_D_imag) {
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A_real, args.ptr_B_real);
std::swap(args.ptr_A_imag, args.ptr_B_imag);
std::swap(args.lda_real, args.ldb_real);
std::swap(args.lda_imag, args.ldb_imag);
std::swap(args.batch_stride_A, args.batch_stride_B);
std::swap(args.batch_stride_A_imag, args.batch_stride_B_imag);
return args;
}
};
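// Illustrative host-side construction (a sketch only; alpha, beta, the sizes, pointers,
// and leading dimensions below are hypothetical placeholders):
//
//   Arguments args(
//     GemmUniversalMode::kGemm,
//     {m, n, k},                                 // problem size
//     1,                                         // batch count
//     {alpha, beta},                             // epilogue parameters
//     A_real, A_imag, B_real, B_imag,            // operand planes
//     C_real, C_imag, D_real, D_imag,
//     lda, lda, ldb, ldb, ldc, ldc, ldd, ldd);   // per-plane leading dimensions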
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
typename Mma::IteratorA::Params params_A_real;
typename Mma::IteratorA::Params params_A_imag;
typename Mma::IteratorB::Params params_B_real;
typename Mma::IteratorB::Params params_B_imag;
typename Epilogue::OutputTileIterator::Params params_C_real;
typename Epilogue::OutputTileIterator::Params params_C_imag;
typename Epilogue::OutputTileIterator::Params params_D_real;
typename Epilogue::OutputTileIterator::Params params_D_imag;
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A_real;
void * ptr_A_imag;
void * ptr_B_real;
void * ptr_B_imag;
void * ptr_C_real;
void * ptr_C_imag;
void * ptr_D_real;
void * ptr_D_imag;
int64_t batch_stride_A;
int64_t batch_stride_A_imag;
int64_t batch_stride_B;
int64_t batch_stride_B_imag;
int64_t batch_stride_C;
int64_t batch_stride_C_imag;
int64_t batch_stride_D;
int64_t batch_stride_D_imag;
int *semaphore;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr),
batch_stride_A(0),
batch_stride_A_imag(0),
batch_stride_B(0),
batch_stride_B_imag(0),
batch_stride_C(0),
batch_stride_C_imag(0),
batch_stride_D(0),
batch_stride_D_imag(0),
semaphore(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
params_A_real(args.lda_real),
params_A_imag(args.lda_imag),
params_B_real(args.ldb_real),
params_B_imag(args.ldb_imag),
params_C_real(args.ldc_real),
params_C_imag(args.ldc_imag),
params_D_real(args.ldd_real),
params_D_imag(args.ldd_imag),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A_real(const_cast<void *>(args.ptr_A_real)),
ptr_A_imag(const_cast<void *>(args.ptr_A_imag)),
ptr_B_real(const_cast<void *>(args.ptr_B_real)),
ptr_B_imag(const_cast<void *>(args.ptr_B_imag)),
ptr_C_real(const_cast<void *>(args.ptr_C_real)),
ptr_C_imag(const_cast<void *>(args.ptr_C_imag)),
ptr_D_real(args.ptr_D_real),
ptr_D_imag(args.ptr_D_imag),
batch_stride_A(args.batch_stride_A),
batch_stride_A_imag(args.batch_stride_A_imag),
batch_stride_B(args.batch_stride_B),
batch_stride_B_imag(args.batch_stride_B_imag),
batch_stride_C(args.batch_stride_C),
batch_stride_C_imag(args.batch_stride_C_imag),
batch_stride_D(args.batch_stride_D),
batch_stride_D_imag(args.batch_stride_D_imag),
semaphore(static_cast<int *>(workspace)) {
}
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_A_real = const_cast<void *>(args.ptr_A_real);
ptr_A_imag = const_cast<void *>(args.ptr_A_imag);
ptr_B_real = const_cast<void *>(args.ptr_B_real);
ptr_B_imag = const_cast<void *>(args.ptr_B_imag);
ptr_C_real = const_cast<void *>(args.ptr_C_real);
ptr_C_imag = const_cast<void *>(args.ptr_C_imag);
ptr_D_real = const_cast<void *>(args.ptr_D_real);
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);
batch_stride_A = args.batch_stride_A;
batch_stride_A_imag = args.batch_stride_A_imag;
batch_stride_B = args.batch_stride_B;
batch_stride_B_imag = args.batch_stride_B_imag;
batch_stride_C = args.batch_stride_C;
batch_stride_C_imag = args.batch_stride_C_imag;
batch_stride_D = args.batch_stride_D;
batch_stride_D_imag = args.batch_stride_D_imag;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmPlanarComplex() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const &args) {
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
(args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
(args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A_real = static_cast<ElementA *>(params.ptr_A_real);
ElementA *ptr_A_imag = static_cast<ElementA *>(params.ptr_A_imag);
ElementB *ptr_B_real = static_cast<ElementB *>(params.ptr_B_real);
ElementB *ptr_B_imag = static_cast<ElementB *>(params.ptr_B_imag);
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A;
ptr_A_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A_imag;
ptr_B_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B;
ptr_B_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B_imag;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A_real = static_cast<ElementA * const *>(params.ptr_A_real)[threadblock_tile_offset.k()];
ptr_A_imag = static_cast<ElementA * const *>(params.ptr_A_imag)[threadblock_tile_offset.k()];
ptr_B_real = static_cast<ElementB * const *>(params.ptr_B_real)[threadblock_tile_offset.k()];
ptr_B_imag = static_cast<ElementB * const *>(params.ptr_B_imag)[threadblock_tile_offset.k()];
}
__syncthreads();
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A_real(
params.params_A_real,
ptr_A_real,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorA iterator_A_imag(
params.params_A_imag,
ptr_A_imag,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B_real(
params.params_B_real,
ptr_B_real,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
typename Mma::IteratorB iterator_B_imag(
params.params_B_imag,
ptr_B_imag,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A_real,
iterator_A_imag,
iterator_B_real,
iterator_B_imag,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C_real = static_cast<ElementC *>(params.ptr_C_real);
ElementC *ptr_C_imag = static_cast<ElementC *>(params.ptr_C_imag);
ElementC *ptr_D_real = static_cast<ElementC *>(params.ptr_D_real);
ElementC *ptr_D_imag = static_cast<ElementC *>(params.ptr_D_imag);
//
// Fetch pointers based on mode.
//
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k());
}
}
else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
ptr_D_real += threadblock_tile_offset.k() * params.batch_stride_D;
ptr_D_imag += threadblock_tile_offset.k() * params.batch_stride_D_imag;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C;
ptr_C_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C_imag;
ptr_D_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D;
ptr_D_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D_imag;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_C_real = static_cast<ElementC * const *>(params.ptr_C_real)[threadblock_tile_offset.k()];
ptr_C_imag = static_cast<ElementC * const *>(params.ptr_C_imag)[threadblock_tile_offset.k()];
ptr_D_real = static_cast<ElementC * const *>(params.ptr_D_real)[threadblock_tile_offset.k()];
ptr_D_imag = static_cast<ElementC * const *>(params.ptr_D_imag)[threadblock_tile_offset.k()];
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C_real(
params.params_C_real,
ptr_C_real,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_C_imag(
params.params_C_imag,
ptr_C_imag,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D_real(
params.params_D_real,
ptr_D_real,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_D_imag(
params.params_D_imag,
ptr_D_imag,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
//
// Construct epilogue
//
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C_real = iterator_D_real;
iterator_C_imag = iterator_D_imag;
}
semaphore.wait(threadblock_tile_offset.k());
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D_real,
iterator_D_imag,
accumulators,
iterator_C_real,
iterator_C_imag);
//
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
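The kernel above relies on the planar complex storage convention: each complex operand is
supplied as two disjoint real-valued planes, each with its own pointer and leading dimension,
rather than as interleaved complex elements. A small host-side sketch (hypothetical helper,
row-major layout assumed) that splits an interleaved matrix into the two planes:

```
#include <complex>
#include <vector>

// Split an interleaved, row-major complex matrix into disjoint real and imaginary planes.
void split_planar(std::vector<std::complex<float> > const &interleaved,
                  int rows, int cols, int ld,
                  std::vector<float> &real_plane, std::vector<float> &imag_plane) {
  real_plane.resize(size_t(rows) * ld);
  imag_plane.resize(size_t(rows) * ld);
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      real_plane[size_t(i) * ld + j] = interleaved[size_t(i) * ld + j].real();
      imag_plane[size_t(i) * ld + j] = interleaved[size_t(i) * ld + j].imag();
    }
  }
}
```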

@ -0,0 +1,583 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Kernel computing planar complex batched GEMMs over arrays of matrix pointers, with problem sizes that may vary per batch.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmPlanarComplexArray {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using Operator = typename Mma::Operator;
using ArchTag = typename Mma::ArchTag;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value);
//
// Additional types needed for reflection
//
using ElementAccumulator = typename Mma::Policy::Operator::ElementC;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::Shape;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
//
// Arguments structure
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
int const *ptr_M;
int const *ptr_N;
int const *ptr_K;
void const * const * ptr_A_real;
void const * const * ptr_A_imag;
void const * const * ptr_B_real;
void const * const * ptr_B_imag;
void const * const * ptr_C_real;
void const * const * ptr_C_imag;
void * const * ptr_D_real;
void * const * ptr_D_imag;
int lda_real;
int lda_imag;
int ldb_real;
int ldb_imag;
int ldc_real;
int ldc_imag;
int ldd_real;
int ldd_imag;
int64_t batch_stride_D; // unused
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kArray),
batch_count(1),
ptr_M(nullptr),
ptr_N(nullptr),
ptr_K(nullptr),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr),
batch_stride_D(0)
{ }
/// constructs an arguments structure
Arguments(
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
int const *ptr_M,
int const *ptr_N,
int const *ptr_K,
void const * const * ptr_A_real,
void const * const * ptr_A_imag,
void const * const * ptr_B_real,
void const * const * ptr_B_imag,
void const * const * ptr_C_real,
void const * const * ptr_C_imag,
void * const * ptr_D_real,
void * const * ptr_D_imag,
int lda_real,
int lda_imag,
int ldb_real,
int ldb_imag,
int ldc_real,
int ldc_imag,
int ldd_real,
int ldd_imag
):
mode(GemmUniversalMode::kArray),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_M(ptr_M),
ptr_N(ptr_N),
ptr_K(ptr_K),
ptr_A_real(ptr_A_real),
ptr_A_imag(ptr_A_imag),
ptr_B_real(ptr_B_real),
ptr_B_imag(ptr_B_imag),
ptr_C_real(ptr_C_real),
ptr_C_imag(ptr_C_imag),
ptr_D_real(ptr_D_real),
ptr_D_imag(ptr_D_imag),
lda_real(lda_real),
lda_imag(lda_imag),
ldb_real(ldb_real),
ldb_imag(ldb_imag),
ldc_real(ldc_real),
ldc_imag(ldc_imag),
ldd_real(ldd_real),
ldd_imag(ldd_imag),
batch_stride_D(0) {
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_M, args.ptr_N);
std::swap(args.ptr_A_real, args.ptr_B_real);
std::swap(args.ptr_A_imag, args.ptr_B_imag);
std::swap(args.lda_real, args.ldb_real);
std::swap(args.lda_imag, args.ldb_imag);
return args;
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
typename Mma::IteratorA::Params params_A_real;
typename Mma::IteratorA::Params params_A_imag;
typename Mma::IteratorB::Params params_B_real;
typename Mma::IteratorB::Params params_B_imag;
typename Epilogue::OutputTileIterator::Params params_C_real;
typename Epilogue::OutputTileIterator::Params params_C_imag;
typename Epilogue::OutputTileIterator::Params params_D_real;
typename Epilogue::OutputTileIterator::Params params_D_imag;
typename EpilogueOutputOp::Params output_op;
int batch_count;
int const *ptr_M;
int const *ptr_N;
int const *ptr_K;
void const * const * ptr_A_real;
void const * const * ptr_A_imag;
void const * const * ptr_B_real;
void const * const * ptr_B_imag;
void const * const * ptr_C_real;
void const * const * ptr_C_imag;
void * const * ptr_D_real;
void * const * ptr_D_imag;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
batch_count(0),
ptr_M(nullptr),
ptr_N(nullptr),
ptr_K(nullptr),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size = 0, // ignored
void *workspace = nullptr // ignored
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
ptr_M(args.ptr_M),
ptr_N(args.ptr_N),
ptr_K(args.ptr_K),
params_A_real(args.lda_real),
params_A_imag(args.lda_imag),
params_B_real(args.ldb_real),
params_B_imag(args.ldb_imag),
params_C_real(args.ldc_real),
params_C_imag(args.ldc_imag),
params_D_real(args.ldd_real),
params_D_imag(args.ldd_imag),
output_op(args.epilogue),
batch_count(args.batch_count),
ptr_A_real(args.ptr_A_real),
ptr_A_imag(args.ptr_A_imag),
ptr_B_real(args.ptr_B_real),
ptr_B_imag(args.ptr_B_imag),
ptr_C_real(args.ptr_C_real),
ptr_C_imag(args.ptr_C_imag),
ptr_D_real(args.ptr_D_real),
ptr_D_imag(args.ptr_D_imag) {
}
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_M = args.ptr_M;
ptr_N = args.ptr_N;
ptr_K = args.ptr_K;
ptr_A_real = args.ptr_A_real;
ptr_A_imag = args.ptr_A_imag;
ptr_B_real = args.ptr_B_real;
ptr_B_imag = args.ptr_B_imag;
ptr_C_real = args.ptr_C_real;
ptr_C_imag = args.ptr_C_imag;
ptr_D_real = args.ptr_D_real;
ptr_D_imag = args.ptr_D_imag;
output_op = args.epilogue;
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmPlanarComplexArray() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const &args) {
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
(args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
(args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
int batch_idx = threadblock_tile_offset.k();
int problem_size_m = params.problem_size.m();
int problem_size_n = params.problem_size.n();
int problem_size_k = params.problem_size.k();
ElementA *ptr_A_real = static_cast<ElementA *>(const_cast<void *>(params.ptr_A_real[batch_idx]));
ElementA *ptr_A_imag = static_cast<ElementA *>(const_cast<void *>(params.ptr_A_imag[batch_idx]));
ElementB *ptr_B_real = static_cast<ElementB *>(const_cast<void *>(params.ptr_B_real[batch_idx]));
ElementB *ptr_B_imag = static_cast<ElementB *>(const_cast<void *>(params.ptr_B_imag[batch_idx]));
//
// If pointers for problem sizes are specified, these are loaded from global memory
//
if (params.ptr_M) {
problem_size_m = params.ptr_M[batch_idx];
}
if (params.ptr_N) {
problem_size_n = params.ptr_N[batch_idx];
}
if (params.ptr_K) {
problem_size_k = params.ptr_K[batch_idx];
}
int const kBlockCountM = (problem_size_m + Mma::Shape::kM - 1) / Mma::Shape::kM;
int const kBlockCountN = (problem_size_n + Mma::Shape::kN - 1) / Mma::Shape::kN;
int const kGemmKIterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
//
// Each threadblock loops over the logical problem size which the kernel may have discovered
// after the grid is launched.
//
CUTLASS_PRAGMA_NO_UNROLL
for (int block_m = threadblock_tile_offset.m();
block_m < kBlockCountM;
block_m += params.grid_tiled_shape.m()) {
CUTLASS_PRAGMA_NO_UNROLL
for (int block_n = threadblock_tile_offset.n();
block_n < kBlockCountN;
block_n += params.grid_tiled_shape.n()) {
//
// Compute indices within threadblock and warp.
//
int thread_idx = threadIdx.x;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Proceed with regular GEMM logic.
//
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{ block_m * Mma::Shape::kM, 0};
cutlass::MatrixCoord tb_offset_B{ 0, block_n * Mma::Shape::kN };
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A_real(
params.params_A_real,
ptr_A_real,
{problem_size_m, problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorA iterator_A_imag(
params.params_A_imag,
ptr_A_imag,
{problem_size_m, problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B_real(
params.params_B_real,
ptr_B_real,
{problem_size_k, problem_size_n},
thread_idx,
tb_offset_B);
typename Mma::IteratorB iterator_B_imag(
params.params_B_imag,
ptr_B_imag,
{problem_size_k, problem_size_n},
thread_idx,
tb_offset_B);
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
mma(
kGemmKIterations,
accumulators,
iterator_A_real,
iterator_A_imag,
iterator_B_real,
iterator_B_imag,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
//assume identity swizzle
MatrixCoord threadblock_offset(
block_m * Mma::Shape::kM,
block_n * Mma::Shape::kN
);
ElementC *ptr_C_real = static_cast<ElementC *>(const_cast<void *>(params.ptr_C_real[batch_idx]));
ElementC *ptr_C_imag = static_cast<ElementC *>(const_cast<void *>(params.ptr_C_imag[batch_idx]));
ElementC *ptr_D_real = static_cast<ElementC *>(params.ptr_D_real[batch_idx]);
ElementC *ptr_D_imag = static_cast<ElementC *>(params.ptr_D_imag[batch_idx]);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C_real(
params.params_C_real,
ptr_C_real,
{problem_size_m, problem_size_n},
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_C_imag(
params.params_C_imag,
ptr_C_imag,
{problem_size_m, problem_size_n},
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D_real(
params.params_D_real,
ptr_D_real,
{problem_size_m, problem_size_n},
thread_idx,
threadblock_offset
);
typename Epilogue::OutputTileIterator iterator_D_imag(
params.params_D_imag,
ptr_D_imag,
{problem_size_m, problem_size_n},
thread_idx,
threadblock_offset
);
//
// Construct epilogue
//
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D_real,
iterator_D_imag,
accumulators,
iterator_C_real,
iterator_C_imag);
} // for block_n
} // for block_m
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
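GemmPlanarComplexArray optionally reads per-batch problem dimensions from device memory: when
ptr_M, ptr_N, or ptr_K is non-null, the corresponding dimension is loaded per batch inside the
kernel, while a null pointer keeps the uniform problem_size for every batch. A host-side sketch
(hypothetical helper name, error checking omitted) for preparing one of these arrays:

```
#include <cuda_runtime.h>
#include <vector>

// Upload per-batch dimensions; the returned device pointer may be passed as ptr_M, ptr_N, or ptr_K.
int *upload_batch_sizes(std::vector<int> const &sizes) {
  int *device_sizes = nullptr;
  cudaMalloc((void **)&device_sizes, sizeof(int) * sizes.size());
  cudaMemcpy(device_sizes, sizes.data(), sizeof(int) * sizes.size(), cudaMemcpyHostToDevice);
  return device_sizes;
}
```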

@ -0,0 +1,537 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Universal GEMM kernel supporting conventional, batched, array, and serial/parallel split-K modes through a single argument structure.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmUniversal {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::Shape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A;
void const * ptr_B;
void const * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
int lda;
int ldb;
int ldc;
int ldd;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { }
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
int lda,
int ldb,
int ldc,
int ldd
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) {
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.batch_stride_A, args.batch_stride_B);
return args;
}
};
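// Illustrative host-side construction (a sketch only; sizes, pointers, strides, and leading
// dimensions are hypothetical placeholders). The same structure serves the kGemm,
// kGemmSplitKParallel, kBatched, and kArray modes; in the batched and array modes,
// threadblock_tile_offset.k() indexes the batch.
//
//   Arguments args(
//     GemmUniversalMode::kGemm,
//     {m, n, k},                     // problem size
//     1,                             // batch count
//     {alpha, beta},                 // epilogue parameters
//     A, B, C, D,                    // operand pointers
//     0, 0, 0, 0,                    // batch strides (unused in kGemm mode)
//     lda, ldb, ldc, ldd);           // leading dimensions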
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
void * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
int *semaphore;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C(0),
batch_stride_D(0),
semaphore(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
params_A(args.lda),
params_B(args.ldb),
params_C(args.ldc),
params_D(args.ldd),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_D(args.ptr_D),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
semaphore(static_cast<int *>(workspace)) {
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmUniversal() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
__syncthreads();
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute the number of threadblock-scoped mainloop iterations over the GEMM K dimension
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
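// Worked example (shape assumed for illustration): with Mma::Shape::kK == 32 and
// problem_size_k - offset_k == 100, gemm_k_iterations == (100 + 31) / 32 == 4.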
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
// Assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
//
// Fetch pointers based on mode.
//
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k());
}
}
else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_C = static_cast<ElementC * const *>(params.ptr_C)[threadblock_tile_offset.k()];
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
ptr_C,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.problem_size.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D,
accumulators,
iterator_C);
//
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
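// Serialized split-K in kGemm mode: each K partition waits until the semaphore holds its
// own partition index k, accumulates its partial result into D, then releases k + 1; the
// final partition writes 0, resetting the semaphore for later launches that reuse the
// same workspace.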
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

0
include/cutlass/gemm/kernel/gemv_batched_strided.h Normal file → Executable file
View File

0
include/cutlass/gemm/threadblock/default_gemv_core.h Normal file → Executable file
View File

View File

@ -33,6 +33,7 @@
#include "cutlass/arch/arch.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
@ -143,8 +144,9 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
layout::RowMajor, typename MmaCore::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass Simt)
/// Specialization for row-major output (OperatorClass TensorOp)
template <
/// Element type for A matrix operand
typename ElementA,
@ -199,8 +201,8 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
layout::RowMajor, typename MmaCore::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Specialization for column-major-interleaved output
template <
/// Element type for A matrix operand
@ -268,7 +270,9 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Specialization for SIMT IDP4A Kernels
template <
/// Layout type for A matrix operand
@ -326,6 +330,8 @@ struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
layout::RowMajor, typename MmaCore::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Specialization for Wmma TensorOp operator with 2 staged pipeline
template <
@ -384,6 +390,8 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
LayoutC, typename MmaCore::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for Wmma TensorOp operator with 1 staged pipeline
template <
///< Element type for A matrix operand
@ -440,6 +448,7 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
LayoutC, typename MmaCore::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
#endif //CUTLASS_ARCH_WMMA_ENABLED

View File

@ -598,6 +598,8 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: column-major-interleave32

View File

@ -0,0 +1,124 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Default template constructing a pipelined, planar complex threadblock-scoped
    matrix multiply from the corresponding real-valued DefaultMma configuration.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/gemm/warp/mma_planar_complex.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/mma_planar_complex_pipelined.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex transformation on operand A
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex transformation on operand B
ComplexTransform TransformB = ComplexTransform::kNone,
/// Math operator tag (e.g. arch::OpMultiplyAdd)
typename Operator = arch::OpMultiplyAdd
>
struct DefaultMmaPlanarComplexPipelined {
// Construct a planar complex variant from the real-valued variant
using RealMma = typename DefaultMma<
ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator_,
LayoutC_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
Stages,
Operator
>::ThreadblockMma;
using ThreadblockMma = MmaPlanarComplexPipelined<
ThreadblockShape_,
typename RealMma::IteratorA,
typename RealMma::SmemIteratorA,
typename RealMma::IteratorB,
typename RealMma::SmemIteratorB,
ElementAccumulator_,
LayoutC_,
typename RealMma::Policy,
Stages,
TransformA,
TransformB
>;
};
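// The planar complex threadblock-scoped MMA is composed from the real-valued DefaultMma:
// its global-memory iterators, shared-memory iterators, and warp-level policy are reused,
// and MmaPlanarComplexPipelined issues separate loads for the real and imaginary planes
// of each operand.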
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

0
include/cutlass/gemm/threadblock/gemv.h Normal file → Executable file
View File

View File

@ -118,6 +118,15 @@ public:
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for MmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
@ -256,8 +265,8 @@ public:
__syncthreads();
++this->smem_iterator_B_;
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1) {
@ -299,7 +308,8 @@ public:
}
}
warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
warp_mma(accum, warp_frag_A[warp_mma_k % 2],
warp_frag_B[warp_mma_k % 2], accum);
}
}
@ -311,3 +321,5 @@ public:
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,201 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Base class template for double-buffered threadblock-scoped planar complex GEMMs.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Base class for threadblock-scoped planar complex matrix products. Defines the shared
/// memory layout (real and imaginary planes) and the warp-level tile iterators.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class MmaPlanarComplexBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages +
Policy::SmemPaddingA::kColumn>;
/// Stride to the imaginary part of the A operand
static int const kImaginaryStrideA = ShapeA::kCount;
/// Shape of the B matrix operand in shared memory
using ShapeB =
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Stride to the imaginary part of the B operand
static int const kImaginaryStrideB = ShapeB::kCount;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount + kImaginaryStrideA> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount + kImaginaryStrideB> operand_B;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
};
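// Shared memory layout: each operand buffer stores the real plane in its first
// ShapeA::kCount (resp. ShapeB::kCount) elements and the imaginary plane immediately
// after it, so kImaginaryStrideA / kImaginaryStrideB is the single pointer offset that
// moves an iterator from the real plane to the imaginary plane.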
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
MmaPlanarComplexBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,422 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped planar complex GEMM kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_planar_complex_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute a planar complex matrix product using a double-buffered
/// (two-stage) pipelined mainloop.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Transformation applied to A
ComplexTransform TransformA = ComplexTransform::kNone,
/// Transformation applied to B
ComplexTransform TransformB = ComplexTransform::kNone
>
class MmaPlanarComplexPipelined :
public MmaPlanarComplexBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = MmaPlanarComplexBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using ArchTag = typename Policy::Operator::ArchTag;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
/// Transformation applied to A
static ComplexTransform const kTransformA = TransformA;
/// Transformation applied to B
static ComplexTransform const kTransformB = TransformB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = ArrayPlanarComplex<
typename Policy::Operator::FragmentC::Element,
Policy::Operator::FragmentC::kElements
>;
/// Warp-level Mma
using Operator = typename Policy::Operator;
private:
using FragmentA = typename IteratorA::Fragment;
using FragmentB = typename IteratorB::Fragment;
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
MmaPlanarComplexPipelined(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
private:
CUTLASS_DEVICE
void warp_mma_planar_complex(
Operator & warp_mma,
FragmentC &accum,
WarpFragmentA const & real_A,
WarpFragmentA const & imag_A,
WarpFragmentB const & real_B,
WarpFragmentB const & imag_B) {
cutlass::negate<Array<typename WarpFragmentB::Element, WarpFragmentB::kElements>> neg_op_B;
WarpFragmentB neg_real_B = neg_op_B(real_B);
WarpFragmentB neg_imag_B = neg_op_B(imag_B);
warp_mma(accum.real, real_A, real_B, accum.real);
if (kTransformB == ComplexTransform::kNone) {
warp_mma(accum.imag, real_A, imag_B, accum.imag);
}
else {
warp_mma(accum.imag, real_A, neg_imag_B, accum.imag);
}
if (kTransformA == ComplexTransform::kNone) {
warp_mma(accum.imag, imag_A, real_B, accum.imag);
}
else {
warp_mma(accum.imag, imag_A, neg_real_B, accum.imag);
}
if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) {
warp_mma(accum.real, imag_A, imag_B, accum.real);
}
else {
warp_mma(accum.real, imag_A, neg_imag_B, accum.real);
}
}
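// The four warp-level MMAs above implement, per element and with no conjugation,
//
//   (a_r + i a_i)(b_r + i b_i) = (a_r b_r - a_i b_i) + i (a_r b_i + a_i b_r).
//
// Conjugating A or B negates the corresponding imaginary operand, which is why the
// real-part update adds imag_A * imag_B when exactly one operand is conjugated (the XOR
// condition above) and subtracts it (via neg_imag_B) otherwise.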
public:
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC &accum,
///< iterator over A operand in global memory
IteratorA iterator_A_real,
///< iterator over A operand in global memory
IteratorA iterator_A_imag,
///< iterator over B operand in global memory
IteratorB iterator_B_real,
///< iterator over B operand in global memory
IteratorB iterator_B_imag,
///< initial value of accumulator
FragmentC const &src_accum) {
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A_real;
FragmentA tb_frag_A_imag;
FragmentB tb_frag_B_real;
FragmentB tb_frag_B_imag;
tb_frag_A_real.clear();
tb_frag_A_imag.clear();
tb_frag_B_real.clear();
tb_frag_B_imag.clear();
// The last k-block is loaded in the prologue
iterator_A_real.load(tb_frag_A_real);
iterator_A_imag.load(tb_frag_A_imag);
iterator_B_real.load(tb_frag_B_real);
iterator_B_imag.load(tb_frag_B_imag);
++iterator_A_real;
++iterator_A_imag;
++iterator_B_real;
++iterator_B_imag;
this->smem_iterator_A_.store(tb_frag_A_real);
this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA);
this->smem_iterator_B_.store(tb_frag_B_real);
this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB);
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_real_A[2];
WarpFragmentA warp_frag_imag_A[2];
WarpFragmentB warp_frag_real_B[2];
WarpFragmentB warp_frag_imag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_real_A[0]);
this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA);
this->warp_tile_iterator_B_.load(warp_frag_real_B[0]);
this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
if (gemm_k_iterations <= 1) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(tb_frag_A_real);
this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA);
this->smem_iterator_B_.store(tb_frag_B_real);
this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB);
__syncthreads();
++this->smem_iterator_B_;
++this->smem_iterator_A_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else {
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA);
this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_A_real.load(tb_frag_A_real);
iterator_A_imag.load(tb_frag_A_imag);
iterator_B_real.load(tb_frag_B_real);
iterator_B_imag.load(tb_frag_B_imag);
++iterator_A_real;
++iterator_A_imag;
++iterator_B_real;
++iterator_B_imag;
// Avoid reading out of bounds if this was the last loop iteration
if (gemm_k_iterations <= 2) {
iterator_A_real.clear_mask();
iterator_A_imag.clear_mask();
iterator_B_real.clear_mask();
iterator_B_imag.clear_mask();
}
}
warp_mma_planar_complex(
warp_mma,
accum,
warp_frag_real_A[warp_mma_k % 2],
warp_frag_imag_A[warp_mma_k % 2],
warp_frag_real_B[warp_mma_k % 2],
warp_frag_imag_B[warp_mma_k % 2]);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -168,7 +168,6 @@ public:
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
@ -183,8 +182,9 @@ public:
++iterator_B;
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
WarpFragmentA warp_frag_A;
WarpFragmentB warp_frag_B;
Operator warp_mma;
// Avoid reading out of bounds
@ -193,7 +193,6 @@ public:
iterator_B.clear_mask();
}
//
// Mainloop
//
@ -203,7 +202,6 @@ public:
this->smem_iterator_A_.store(tb_frag_A);
this->smem_iterator_B_.store(tb_frag_B);
__syncthreads();
//
@ -216,16 +214,16 @@ public:
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k) % 2]);
this->warp_tile_iterator_A_.load(warp_frag_A);
this->warp_tile_iterator_B_.load(warp_frag_B);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
warp_mma(accum, warp_frag_A, warp_frag_B, accum);
}
// Add negative offsets to return smem load iterators to the 'start' of the shared memory

View File

@ -30,7 +30,8 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -141,6 +142,55 @@ struct GemmIdentityThreadblockSwizzle {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// A special version of GemmIdentityThreadblockSwizzle. See the choice of kTile below.
template <typename LayoutA_, typename LayoutB_>
struct GemmCohortThreadblockSwizzle
{
const int kTile =
(platform::is_same<LayoutA_, cutlass::layout::RowMajor>::value ||
platform::is_same<LayoutB_, cutlass::layout::ColumnMajor>::value)
? 4
: 1;
CUTLASS_HOST_DEVICE
GemmCohortThreadblockSwizzle() { }
/// Returns the shape of the problem in units of logical tiles
CUTLASS_HOST_DEVICE
GemmCoord get_tiled_shape(
GemmCoord problem_size,
GemmCoord tile_size,
int split_k_slices) const {
return GemmCoord(
(problem_size.m() + tile_size.m() - 1) / tile_size.m(),
(problem_size.n() + tile_size.n() - 1) / tile_size.n(),
split_k_slices);
}
/// Computes CUDA grid dimensions given a size in units of logical tiles
CUTLASS_HOST_DEVICE
dim3 get_grid_shape(GemmCoord tiled_shape) const {
return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
}
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset() const {
int block_idx_x = RematerializeBlockIdxX();
int block_idx_y = RematerializeBlockIdxY();
return GemmCoord{
(block_idx_x / kTile),
(block_idx_y * kTile) + (block_idx_x % kTile),
RematerializeBlockIdxZ()
};
}
};
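// Example mapping (kTile == 4, values chosen for illustration): tiled_shape == (8, 8, 1)
// yields grid (32, 2, 1); blockIdx (x=5, y=1) maps to tile (m, n) = (5 / 4, 1 * 4 + 5 % 4)
// = (1, 5). Four consecutive blockIdx.x values therefore cover a contiguous strip of tiles
// along N that share the same rows of A, which is intended to improve L2 reuse.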
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Threadblock swizzling function for GEMMs
struct GemmHorizontalThreadblockSwizzle {
@ -186,8 +236,8 @@ struct GemmBatchedIdentityThreadblockSwizzle {
CUTLASS_HOST_DEVICE
GemmCoord get_tiled_shape(
GemmCoord problem_size,
int batch_count,
GemmCoord tile_size) const {
GemmCoord tile_size,
int batch_count) const {
return GemmCoord(
(problem_size.m() + tile_size.m() - 1) / tile_size.m(),
@ -207,7 +257,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
return GemmCoord{
RematerializeBlockIdxX(),
RematerializeBlockIdxY(),
0
RematerializeBlockIdxZ()
};
}

View File

@ -114,3 +114,6 @@ struct DefaultMmaTensorOp {
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,351 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename RealElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename RealElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename RealElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Complex transform on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex transform on B operand
ComplexTransform TransformB = ComplexTransform::kNone,
/// Used for partial specialization
typename Enable = bool
>
class MmaComplexTensorOp;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename RealElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename RealElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename RealElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Complex transform on A operand
ComplexTransform TransformA,
/// Complex transform on B operand
ComplexTransform TransformB,
/// Used for partial specialization
typename Enable
>
class MmaComplexTensorOp<
Shape_,
complex<RealElementA>,
LayoutA_,
complex<RealElementB>,
LayoutB_,
complex<RealElementC>,
LayoutC_,
Policy_,
TransformA,
TransformB,
Enable> {
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = complex<RealElementA>;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = complex<RealElementB>;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = complex<RealElementC>;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Complex transform on A operand
static ComplexTransform const kTransformA = TransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = TransformB;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
public:
/// Iterates over the A operand in memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>,
Operand::kA,
ElementA,
LayoutA,
MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
Policy::OpDelta::kRow,
32,
1
>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>,
Operand::kB,
ElementB,
LayoutB,
MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
Policy::OpDelta::kColumn,
32,
1
>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
static_assert(
!(Shape::kM % Policy::Operator::Shape::kM) &&
!(Shape::kN % Policy::Operator::Shape::kN),
"Shape of warp-level Mma must be divisible by operator shape.");
/// Number of mma operations performed
using MmaIterations = MatrixShape<
Shape::kM / Policy::Operator::Shape::kM,
Shape::kN / Policy::Operator::Shape::kN
>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<
MatrixShape<Shape::kM, Shape::kN>,
ElementC,
LayoutC,
typename Policy::Operator::Shape,
typename Policy::OpDelta>;
/// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this
/// storage arrangement is to be considered 'planar complex' in the sense that all real-valued
/// parts are stored consecutively followed by all imaginary parts. This matches the structure
/// of Tensor Cores which are always real-valued matrix multiplies.
using FragmentC = typename IteratorC::Fragment;
static_assert(
FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,
"Unexpected planar complex fragment length.");
private:
//
// Data members
//
/// Underlying real-valued matrix multiply operator (concept: arch::Mma)
typename Policy::Operator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaComplexTensorOp() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(
FragmentC &D,
FragmentA const &A,
FragmentB const &B,
FragmentC const &C) const {
// Alias types for underlying real-valued matrix multiply operator
using MmaOperandA = typename Policy::Operator::FragmentA;
using MmaOperandB = typename Policy::Operator::FragmentB;
using MmaOperandC = typename Policy::Operator::FragmentC;
static_assert(MmaOperandA::kElements == 1,
  "This implementation only supports math instructions in which exactly one element is needed for the A operand. "
  "We can generalize later.");
static_assert(MmaOperandB::kElements == 1,
  "This implementation only supports math instructions in which exactly one element is needed for the B operand. "
  "We can generalize later.");
D = C;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {
// mma(accum.real(), a.real(), b.real(), accum.real());
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
// Pack operands together. This may result in actual MOVs
MmaOperandA operand_A;
MmaOperandB operand_B;
operand_A[0] = A[m].real();
operand_B[0] = B[n].real();
// Real-valued accumulator part
MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
(m + n * MmaIterations::kRow);
mma(*accum, operand_A, operand_B, *accum);
}
// mma(accum.imag(), a.real(), b.imag(), accum.imag());
CUTLASS_PRAGMA_UNROLL
for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {
// Pack operands together. This may result in actual MOVs
MmaOperandA operand_A;
MmaOperandB operand_B;
operand_A[0] = A[m].real();
operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());
// Imaginary-valued accumulator part
MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
(m + n * MmaIterations::kRow) + MmaIterations::kCount;
mma(*accum, operand_A, operand_B, *accum);
}
// mma(accum.real(), -a.imag(), b.imag(), accum.real())
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
// Pack operands together. This may result in actual MOVs
MmaOperandA operand_A;
MmaOperandB operand_B;
// A imaginary part is intentionally negated
operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag());
operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());
// Real-valued accumulator part
MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
(m + n * MmaIterations::kRow);
mma(*accum, operand_A, operand_B, *accum);
}
// mma(accum.imag(), a.imag(), b.real(), accum.imag())
CUTLASS_PRAGMA_UNROLL
for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {
// Pack operands together. This may result in actual MOVs
MmaOperandA operand_A;
MmaOperandB operand_B;
operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag());
operand_B[0] = B[n].real();
// Imaginary-valued accumulator part
MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
(m + n * MmaIterations::kRow) + MmaIterations::kCount;
mma(*accum, operand_A, operand_B, *accum);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// TODO - partial specializations of real*complex and complex*real
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,176 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/complex.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/gemm/warp/tile_iterator_planar_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Underlying real-valued warp-level matrix multiply
typename Operator_,
/// Transformation applied to A operand (typically folded into math instruction)
ComplexTransform TransformA = ComplexTransform::kNone,
/// Transformation applied to B operand (typically folded into math instruction)
ComplexTransform TransformB = ComplexTransform::kNone
>
class MmaPlanarComplex {
public:
/// Underlying real-valued warp-level matrix multiply
using Operator = Operator_;
/// Shape of warp-level matrix multiply
using Shape = typename Operator::Shape;
/// Transformation applied to A operand (typically folded into math instruction)
static ComplexTransform const kTransformA = TransformA;
/// Transformation applied to B operand (typically folded into math instruction)
static ComplexTransform const kTransformB = TransformB;
/// Fragment of elements
using FragmentA = ArrayPlanarComplex<typename Operator::ElementA, Operator::FragmentA::kElements>;
/// Iterator into planar complex
using IteratorA = TileIteratorPlanarComplex<typename Operator::IteratorA>;
/// Layout in memory of the A operand
using LayoutA = typename Operator::LayoutA;
using FragmentB = ArrayPlanarComplex<typename Operator::ElementB, Operator::FragmentB::kElements>;
/// Iterator into planar complex
using IteratorB = TileIteratorPlanarComplex<typename Operator::IteratorB>;
/// Layout in memory of the B operand
using LayoutB = typename Operator::LayoutB;
/// Tile iterator for accumulator
using IteratorC = TileIteratorPlanarComplex<typename Operator::IteratorC>;
/// Accumulator fragment
using FragmentC = ArrayPlanarComplex<typename Operator::ElementC, Operator::FragmentC::kElements>;
/// Layout of accumulator fragment in memory
using LayoutC = typename Operator::LayoutC;
private:
/// Number of mma operations performed
using MmaIterations = MatrixShape<
Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,
Operator::Shape::kN / Operator::Policy::Operator::Shape::kN
>;
public:
/// Ctor
CUTLASS_DEVICE
MmaPlanarComplex() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(
FragmentC &D,
FragmentA const &A_in,
FragmentB const &B_in,
FragmentC const &C) const {
D.real = C.real;
D.imag = C.imag;
//
// Transform fragments based on conjugate operations.
//
negate<typename FragmentA::ArrayReal> neg_A;
FragmentA frag_A;
frag_A.real = A_in.real;
if (kTransformA == ComplexTransform::kConjugate) {
frag_A.imag = neg_A(A_in.imag);
}
else {
frag_A.imag = A_in.imag;
}
FragmentB frag_B;
frag_B.real = B_in.real;
if (kTransformB == ComplexTransform::kConjugate) {
negate<typename FragmentB::ArrayReal> neg;
frag_B.imag = neg(B_in.imag);
}
else {
frag_B.imag = B_in.imag;
}
//
// Accumulated real-valued matrix multiplies
//
Operator real_mma;
// D.i += A.i * B.r
real_mma(D.imag, frag_A.imag, frag_B.real, D.imag);
// D.r += A.r * B.r
real_mma(D.real, frag_A.real, frag_B.real, D.real);
// D.i += A.r * B.i
real_mma(D.imag, frag_A.real, frag_B.imag, D.imag);
// D.r += -A.i * B.i
frag_A.imag = neg_A(frag_A.imag);
real_mma(D.real, frag_A.imag, frag_B.imag, D.real);
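// The real-valued warp-level MMA only accumulates, so the D.real -= A.imag * B.imag
// term of the complex product is formed by negating frag_A.imag in registers first and
// issuing a fourth accumulate.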
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -100,6 +100,16 @@ public:
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassSimt;
/// Hard-coded for now
using ArchTag = arch::Sm50;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Layout of threads
using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value,
layout::ColumnMajor,
typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value,
@ -153,6 +163,9 @@ public:
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = FragmentA;
/// Iterates over the B operand in memory
using IteratorB = MmaSimtTileIterator<
MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,
@ -167,6 +180,9 @@ public:
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB = FragmentB;
/// Iterates over the C operand in memory
using IteratorC = MmaSimtTileIterator<
MatrixShape<Shape::kM, Shape::kN>,
@ -201,6 +217,15 @@ public:
mma(d, a, b, c);
}
/// Transform the mma operands to the required types
CUTLASS_DEVICE
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {
//TODO: Implement this
dst_A = A;
dst_B = B;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -31,7 +31,9 @@
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
@ -51,6 +53,60 @@ namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename T, typename S, int N, FloatRoundStyle Round>
struct ConvertAndPack {
using Converter = NumericArrayConverter<T, S, N, Round>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<S, N> const &source) {
Converter converter;
return converter(source);
}
};
template <typename T, int N, FloatRoundStyle Round>
struct ConvertAndPack<T, T, N, Round> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &source) {
return source;
}
};
template <int N, FloatRoundStyle Round>
struct ConvertAndPack<half_t, float, N, Round> {
using Converter = NumericArrayConverter<half_t, float, N, Round>;
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<float, N> const &source) {
Converter converter;
Array<float, N> tmp;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc));
tmp[i] = source[idx];
}
return converter(tmp);
}
};
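// The index expression above permutes each group of four elements as
// {0, 1, 2, 3} -> {0, 2, 1, 3} (e.g. tmp[1] reads source[2] and tmp[2] reads source[1])
// before the float -> half_t conversion, presumably to match the register ordering in
// which the converted fragment is consumed by the Tensor Core operation.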
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
@ -105,9 +161,18 @@ public:
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Architecture tag from underlying instruction
using ArchTag = typename Policy::Operator::ArchTag;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
@ -128,6 +193,10 @@ public:
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA =
Array<typename Policy::Operator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
@ -137,6 +206,10 @@ public:
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename Policy::Operator::ElementB, FragmentB::kElements>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
@ -179,8 +252,8 @@ public:
CUTLASS_DEVICE
void operator()(
FragmentC &D,
FragmentA const &A,
FragmentB const &B,
TransformedFragmentA const &A,
TransformedFragmentB const &B,
FragmentC const &C,
int const &partitionN_idx = 0) const {
@ -221,6 +294,44 @@ public:
}
}
}
/// Transform the mma operands to the required types
CUTLASS_DEVICE
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {
bool midway_depstage =
!(platform::is_same<typename Policy::Operator::ElementA,
ElementA>::value &&
platform::is_same<typename Policy::Operator::ElementB,
ElementB>::value);
//
// Define conversions from source type to instruction type
//
FloatRoundStyle const kRoundA =
PreferredRoundingMode<typename Policy::Operator::ElementA,
ElementA>::kRound;
FloatRoundStyle const kRoundB =
PreferredRoundingMode<typename Policy::Operator::ElementB,
ElementB>::kRound;
detail::ConvertAndPack<typename Policy::Operator::ElementA, ElementA,
FragmentA::kElements, kRoundA>
convert_A;
NumericArrayConverter<typename Policy::Operator::ElementB, ElementB,
FragmentB::kElements / 2, kRoundB>
convert_B;
Array<ElementB, FragmentB::kElements / 2> const *ptr_B =
reinterpret_cast<Array<ElementB, FragmentB::kElements / 2> const *>(&B);
Array<typename Policy::Operator::ElementB, FragmentB::kElements / 2> *
ptr_dst_B = reinterpret_cast<Array<typename Policy::Operator::ElementB,
FragmentB::kElements / 2> *>(&dst_B);
dst_A = convert_A(A);
ptr_dst_B[0] = convert_B(ptr_B[0]);
ptr_dst_B[1] = convert_B(ptr_B[1]);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
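A minimal device-side sketch of the conversion pattern transform() applies, assuming ElementA/ElementB are float, the instruction operands are half_t, and an illustrative fragment size of 8 (the real fragments are iterator-defined Array types); B is converted half a fragment at a time so each half can feed a separate mma step:

```
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"

__device__ void convert_operands_sketch(
    cutlass::Array<float, 8> const &A,
    cutlass::Array<float, 8> const &B,
    cutlass::Array<cutlass::half_t, 8> &dst_A,
    cutlass::Array<cutlass::half_t, 8> &dst_B) {

  // A is converted wholesale.
  cutlass::NumericArrayConverter<cutlass::half_t, float, 8> convert_A;
  dst_A = convert_A(A);

  // B is converted half a fragment at a time.
  cutlass::NumericArrayConverter<cutlass::half_t, float, 4> convert_B;
  auto const *src_B   = reinterpret_cast<cutlass::Array<float, 4> const *>(&B);
  auto *dst_B_halves  = reinterpret_cast<cutlass::Array<cutlass::half_t, 4> *>(&dst_B);

  dst_B_halves[0] = convert_B(src_B[0]);
  dst_B_halves[1] = convert_B(src_B[1]);
}
```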
@ -228,3 +339,5 @@ public:
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -103,6 +103,15 @@ public:
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Architecture tag
using ArchTag = arch::Sm70;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;

View File

@ -199,7 +199,8 @@ public:
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment =
Array<Element, Shape::kContiguous * InstructionShape::kStrided / kThreads>;
private:
@ -516,7 +517,7 @@ class MmaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = typename Base::Fragment;
private:
@ -747,7 +748,7 @@ class MmaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = typename Base::Fragment;
private:
@ -1023,7 +1024,8 @@ class MmaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = Array<Element, Shape::kStrided *
InstructionShape::kContiguous / kThreads>;
private:
@ -1151,7 +1153,8 @@ class MmaTensorOpMultiplicandTileIterator<
int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile;
byte_offset_ ^= k_groups_delta * sizeof_bits<Element>::value *
Layout::kElementsPerAccess / 8;
Layout::kElementsPerAccess *
Policy::LdsmShape::kContiguous / 8;
pointer_ +=
tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor +
whole_tiles * stride_ / sections_;
@ -1406,7 +1409,7 @@ class MmaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = typename Base::Fragment;
private:
/// Underlying tile iterator
@ -1636,7 +1639,7 @@ class MmaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads>;
using Fragment = typename Base::Fragment;
private:
/// Underlying tile iterator

View File

@ -165,7 +165,8 @@ public:
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment = Array<Element, Shape::kContiguous *
InstructionShape::kStrided / kThreads * 2>;
private:
@ -473,7 +474,8 @@ public:
//
/// Fragment object holding a thread's part of a tile; needs twice the number of registers
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment = Array<Element, Shape::kContiguous *
InstructionShape::kStrided / kThreads * 2>;
private:
@ -738,7 +740,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment = typename Base::Fragment;
private:
@ -962,7 +964,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment = typename Base::Fragment;
private:
@ -1557,7 +1559,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment =
Array<Element,
Shape::kStrided * InstructionShape::kContiguous / kThreads * 2>;
private:
@ -1869,7 +1873,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment = typename Base::Fragment;
private:
/// Underlying tile iterator
@ -2097,7 +2101,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator<
//
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, Shape::kCount / kThreads * 2>;
using Fragment = typename Base::Fragment;
private:
/// Underlying tile iterator

View File

@ -106,6 +106,15 @@ public:
/// Shape of the warp in units of thread (concept: MmaTensorOpPolicy)
using Policy = Policy_;
/// Underlying architecture tag
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
@ -193,7 +202,6 @@ public:
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,244 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/array_planar_complex.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TileIterator_>
class TileIteratorPlanarComplex {
public:
/// Underlying iterator over real-valued tiles
using TileIterator = TileIterator_;
/// Underlying element type
using Element = typename TileIterator::Element;
/// Underlying layout type
using Layout = typename TileIterator::Layout;
/// TensorRef type for loading element from a tensor
using TensorRef = typename TileIterator::TensorRef;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Planar complex fragment
using Fragment = ArrayPlanarComplex<Element, TileIterator::Fragment::kElements>;
public:
/// Underlying tile iterator
TileIterator tile_iterator_;
/// Offset (in units of bytes) to the imaginary part of the planar complex matrix
LongIndex imaginary_offset_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
TileIteratorPlanarComplex(): imaginary_offset_(0) { }
/// Constructor from TensorRef
CUTLASS_DEVICE
TileIteratorPlanarComplex(
TensorRef const &ref,
int lane_id,
LongIndex imaginary_offset
):
tile_iterator_(ref, lane_id),
imaginary_offset_((imaginary_offset * sizeof_bits<Element>::value) / 8) { }
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_DEVICE
TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) {
tile_iterator_.add_pointer_offset(offset);
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) {
tile_iterator_.add_tile_offset(tile_offset);
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
TileIteratorPlanarComplex & operator++() {
++tile_iterator_;
return *this;
}
//
// WIP
//
/// Advances the iterator along the opposite of the advance dimension
CUTLASS_HOST_DEVICE
TileIteratorPlanarComplex & operator--() {
--tile_iterator_;
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) {
tile_iterator_.add_tile_offset(tile_offset);
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_DEVICE
TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) {
tile_iterator_.add_tile_offset(-tile_offset);
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
tile_iterator_.load_with_byte_offset(frag.real, 0);
tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset in units of bytes
Index byte_offset) const {
tile_iterator_.load_with_byte_offset(frag.real, byte_offset);
tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_);
}
/// Loads a fragment from memory with additional logical offset
CUTLASS_DEVICE
void load_with_pointer_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a linear offset
Index pointer_offset) const {
Index byte_offset = (pointer_offset * sizeof_bits<Element>::value)/8;
tile_iterator_.load_with_byte_offset(frag.real, byte_offset);
tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset) const {
tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0);
tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index pointer_offset) const {
Index byte_offset = (pointer_offset * sizeof_bits<Element>::value)/8;
tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset);
tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_);
}
/// Loads a fragment from memory with logical offset in units of whole tiles.
CUTLASS_DEVICE
void load_with_byte_offset(
/// fragment to load from the tensor
Fragment &frag,
/// loads a tile with a logical offset in units of whole tiles
TensorCoord const &tile_offset,
/// loads a tile with a logical offset AND a pointer offset
Index byte_offset) const {
tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset);
tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_);
}
/// Notify the iterator which k-group it is currently pointing to.
///
/// This does not advance the iterator. Rather, it overrides its internal
/// tracking with constant-valued k-group index to enable the compiler to
/// fold constants and achieve more efficient code.
///
/// This is used by some nontrivial permuted layouts.
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
tile_iterator_.set_kgroup_index(k_group);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
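A brief host-side sketch of the planar-complex storage convention this iterator assumes: the imaginary plane sits a fixed element offset after the real plane, so every tile load is issued twice, once for frag.real at byte offset 0 and once for frag.imag at the byte offset derived from imaginary_offset. The row-major layout and names below are illustrative only:

```
#include <cstdint>
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

// Returns a TensorRef to the imaginary plane of a planar-complex rows x cols matrix
// whose real plane starts at 'base'. A TileIteratorPlanarComplex constructed on the
// real-plane ref with imaginary_offset = rows * cols reaches the same elements by
// adding (imaginary_offset * sizeof_bits<Element>::value) / 8 bytes per access.
cutlass::TensorRef<float, cutlass::layout::RowMajor>
imaginary_plane(float *base, int rows, int cols) {
  int64_t imaginary_offset = int64_t(rows) * cols;   // real plane, then imaginary plane
  return cutlass::TensorRef<float, cutlass::layout::RowMajor>(
      base + imaginary_offset, cutlass::layout::RowMajor(cols));
}
```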

View File

@ -78,6 +78,10 @@ enum
#include <immintrin.h>
#if defined(__i386__) || defined(__x86_64__)
#include <intrin.h>
#endif
#define F16C_ROUND_NEAREST 0
#if !defined(__CUDA_ARCH__)
@ -110,9 +114,51 @@ __inline unsigned short _cvtss_sh (float __F, const int) {
// Linux
#include <x86intrin.h>
#if defined(__i386__) || defined(__x86_64__)
#include <cpuid.h>
#endif
#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)
#endif
#endif // _MSC_VER
class CpuId {
bool f16c_enabled;
CpuId() {
#if defined(__i386__) || defined(__x86_64__)
#if defined(_MSC_VER)
int exx[4];
__cpuid (exx, 1);
f16c_enabled = exx[2] & 0x20000000;
#else
// GCC / Clang
int eax, ebx, ecx, edx;
__cpuid (1 , eax, ebx, ecx, edx);
f16c_enabled = ecx & 0x20000000;
#endif
#else
// Arm / PowerPC etc.
f16c_enabled = false;
#endif
}
public:
bool is_f16c_supported() const {
return f16c_enabled;
}
static const CpuId& instance() {
static CpuId cpu;
return cpu;
}
};
#endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
///////////////////////////////////////////////////////////////////////////////////////////////////
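A host-only sketch of the detection pattern this change introduces: the F16C feature bit is queried once through the CpuId singleton and the hardware intrinsic is used only when present, with the software conversion as the fallback. CUTLASS performs this check inside half_t's convert routines; the standalone function below is illustrative only:

```
#include "cutlass/numeric_types.h"

float half_bits_to_float(unsigned short bits) {
#if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
  if (CpuId::instance().is_f16c_supported()) {
    return _cvtsh_ss(bits);                        // hardware F16C conversion
  }
#endif
  return float(cutlass::half_t::bitcast(bits));    // software path inside half_t
}
```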
@ -154,10 +200,15 @@ struct alignas(2) half_t {
static half_t convert(float const& flt) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
return half_t(__float2half_rn(flt));
#elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST);
return bitcast(u);
#else
#if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
if( CpuId::instance().is_f16c_supported() ) {
unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST);
return bitcast(u);
}
#endif
// software implementation rounds toward nearest even
unsigned const& s = reinterpret_cast<unsigned const &>(flt);
uint16_t sign = uint16_t((s >> 16) & 0x8000);
@ -248,10 +299,15 @@ struct alignas(2) half_t {
static float convert(half_t const& x) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
return __half2float(x.to_half());
#elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
unsigned short u = x.storage;
return _cvtsh_ss(u);
#else
#if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C
if( CpuId::instance().is_f16c_supported() ) {
unsigned short u = x.storage;
return _cvtsh_ss(u);
}
#endif
uint16_t const &h = x.storage;
int sign = ((h >> 15) & 1);
int exp = ((h >> 10) & 0x1f);
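Either path produces the same round-to-nearest-even result for representable values; a tiny host-side usage check, independent of whether CUTLASS_ENABLE_F16C is set:

```
#include <cstdio>
#include "cutlass/numeric_types.h"

int main() {
  cutlass::half_t h(3.140625f);   // exactly representable in fp16 (201/64)
  float f = float(h);
  std::printf("%f\n", f);         // prints 3.140625
  return 0;
}
```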

View File

@ -29,7 +29,11 @@
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include "cutlass/platform/platform.h"
@ -48,7 +52,7 @@ struct integer_subbyte {
static bool const kSigned = Signed;
/// External type
using T = typename std::conditional<kSigned, int, unsigned>::type;
using T = typename platform::conditional<kSigned, int, unsigned>::type;
/// Storage type
using Storage = uint8_t;

View File

@ -38,7 +38,6 @@
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/tensor_nhwc.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor_op_multiplicand_sm70.h"

View File

@ -98,7 +98,7 @@ public:
/// Assumes coordinate has convention (row, column)
CUTLASS_HOST_DEVICE
LongIndex operator()(MatrixCoord const &coord) const {
return coord.row() * stride_[0] + coord.column();
return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column();
}
/// Inverse of layout function, mapping linear offset to logical coordinate
@ -134,7 +134,7 @@ public:
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(MatrixCoord const &extent) const {
return extent.row() * stride_[0];
return LongIndex(extent.row()) * LongIndex(stride_[0]);
}
};
@ -191,7 +191,7 @@ public:
/// Assumes coordinate has convention (row, column)
CUTLASS_HOST_DEVICE
LongIndex operator()(MatrixCoord const &coord) const {
return coord.row() + coord.column() * stride_[0];
return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row();
}
/// Inverse of layout function, mapping linear offset to logical coordinate
@ -227,7 +227,7 @@ public:
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(MatrixCoord const &extent) const {
return extent.column() * stride_[0];
return LongIndex(extent.column()) * LongIndex(stride_[0]);
}
};
@ -290,7 +290,7 @@ public:
LongIndex operator()(MatrixCoord const &coord) const {
Index row_major = coord.row() / kInterleave;
Index row_minor = coord.row() % kInterleave;
return row_major * stride_[0] + coord.column() * kInterleave + row_minor;
return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor;
}
/// Inverse of layout function, mapping linear offset to logical coordinate
@ -397,7 +397,7 @@ public:
LongIndex operator()(MatrixCoord const &coord) const {
Index column_major = coord.column() / kInterleave;
Index column_minor = coord.column() % kInterleave;
return column_major * stride_[0] + coord.row() * kInterleave + column_minor;
return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor;
}
/// Inverse of layout function, mapping linear offset to logical coordinate
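The LongIndex casts above matter because Index is a 32-bit int in these layouts, so the row/column-by-stride products were previously formed in 32-bit arithmetic and could overflow for large matrices even though the result is returned as a 64-bit LongIndex. A small worked illustration:

```
#include <cstdint>
#include <cstdio>

int main() {
  int32_t row = 50000, ldm = 60000;                 // a large row-major matrix

  int64_t wrong = row * ldm;                        // 32-bit multiply overflows, then widens
  int64_t right = int64_t(row) * int64_t(ldm);      // 64-bit multiply, as in the fixed code

  std::printf("%lld vs %lld\n", (long long)wrong, (long long)right);
  // 'right' is 3000000000; 'wrong' is not (signed overflow, undefined behavior).
  return 0;
}
```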

View File

@ -116,6 +116,11 @@ public:
return PitchLinearCoord(Base::operator-(b));
}
CUTLASS_HOST_DEVICE
PitchLinearCoord operator-() const {
return PitchLinearCoord(-at(0), -at(1));
}
/// Element-wise multiplication
CUTLASS_HOST_DEVICE
PitchLinearCoord operator*(Base const& b) const {
@ -211,7 +216,7 @@ public:
/// Assumes coordinate has convention (contiguous, strided)
CUTLASS_HOST_DEVICE
LongIndex operator()(TensorCoord const &coord) const {
return coord.contiguous() + coord.strided() * stride_[0];
return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]);
}
/// Returns the logical coordinate given an offset.

View File

@ -33,7 +33,11 @@
defined in cutlass/tensor_ref.h.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include "assert.h"
#endif
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/matrix.h"

View File

@ -45,7 +45,7 @@ enum class FloatRoundStyle {
round_toward_zero, ///< round toward zero
round_to_nearest, ///< round to nearest even
round_toward_infinity, ///< round toward infinity
round_toward_neg_infinity, ///< round toward negative infinity
round_toward_neg_infinity ///< round toward negative infinity
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -268,8 +268,7 @@ struct NumericConverterClamp {
result_type const kClamp_min = -kClamp_max - 1;
bool is_int_min = !(s > kClamp_min);
bool is_int_max = !(s < kClamp_max);
return (is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s)));
return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s));
}
CUTLASS_HOST_DEVICE
@ -295,15 +294,15 @@ struct NumericConverterClamp<T, float> {
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & s) {
NumericConverter<result_type, source_type> convert_op;
NumericConverter<result_type, double> convert_op;
float kClamp_max = float((1 << (sizeof_bits<result_type>::value - 1)) - 1);
float kClamp_min = -kClamp_max - 1;
double kClamp_max = double((1U << (sizeof_bits<result_type>::value - 1)) - 1);
double kClamp_min = -kClamp_max - 1;
float source = s;
double source = s;
source = fmaxf(source, kClamp_min);
source = fminf(source, kClamp_max);
source = fmax(source, kClamp_min);
source = fmin(source, kClamp_max);
return convert_op(source);
}
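Computing the clamp bounds in double rather than float matters for 32-bit destinations: 2147483647 is not exactly representable as a float and rounds up to 2147483648.0f, so a float-based clamp can let through a value that still overflows the destination integer. A small worked check:

```
#include <cstdio>

int main() {
  float  f_max = float(2147483647);    // rounds up to 2147483648.0f (24-bit mantissa)
  double d_max = double(2147483647);   // exact

  std::printf("%.1f %.1f\n", f_max, d_max);   // prints 2147483648.0 2147483647.0
  return 0;
}
```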
@ -353,6 +352,24 @@ struct NumericArrayConverter {
}
};
template <
typename T,
int N,
FloatRoundStyle Round
>
struct NumericArrayConverter<T, T, N, Round> {
using result_type = Array<T, N>;
using source_type = Array<T, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Array<half, 2> <= Array<float, 2>, round to nearest
@ -498,10 +515,15 @@ struct NumericArrayConverter<float, half_t, N, Round> {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Partial specialization for Array<int8_t, 1> <= Array<int, 1>
template <
@ -625,12 +647,13 @@ struct NumericArrayConverter<int8_t, int, N, Round> {
return convert(s);
}
};
#endif // Conditional guards to enable partial specialization for packed integers
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Partial specialization for Array<int4b_t, 8> <= Array<int, 8>
template <
@ -707,4 +730,121 @@ struct NumericArrayConverter<int4b_t, int, N, Round> {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// FastNumericArrayConverter only works when the source is within center range.
/// Conversion operator for Array
template <typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
struct FastNumericArrayConverter {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &s) {
result_type result;
NumericArrayConverter<T, S, N, Round> convert_;
return convert_(s);
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
};
/// Partial specialization for Array<float> <= Array<int>
template <typename T, int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<float, T, N, Round> {
using result_type = Array<float, N>;
using source_type = Array<T, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
int tmp = source[i] + 1262485504 /*0x4B400000*/;
result[i] = reinterpret_cast<float const &>(tmp) - 12582912.0f;
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
};
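A host-side check of the bit trick used above, assuming the "center range" caveat means the integer magnitude stays comfortably below 2^22 (as int8 accumulators do): 0x4B400000 is the bit pattern of 12582912.0f = 1.5 * 2^23, where the float spacing is exactly 1.0, so adding n to the bit pattern yields the float 12582912.0f + n and the subtraction recovers float(n) without an integer-to-float instruction:

```
#include <cstdio>
#include <cstring>

int main() {
  int values[] = {-100, -1, 0, 1, 37, 1000};
  for (int n : values) {
    int bits = n + 1262485504;        // 0x4B400000
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    std::printf("%5d -> %f\n", n, f - 12582912.0f);   // prints float(n)
  }
  return 0;
}
```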
/// Partial specialization for Array<int8_t, 4> <= Array<float, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<int8_t, float, 4, Round> {
using result_type = Array<int8_t, 4>;
using source_type = Array<float, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
Array<int32_t, 4> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
float tmp = source[i] + 12582912.0f;
result[i] = reinterpret_cast<int32_t const &>(tmp);
}
result[0] = __byte_perm(result[0], result[1], 0x40);
result[2] = __byte_perm(result[2], result[3], 0x40);
result[0] = __byte_perm(result[0], result[2], 0x5410);
return reinterpret_cast<result_type const &>(result[0]);
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
};
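The reverse trick: for a float already in int8 range, adding 12582912.0f rounds it to the nearest integer and leaves that integer, in two's complement, in the low byte of the bit pattern; the three __byte_perm calls above then gather the four low bytes into one packed 32-bit word. A host-side check of the low-byte claim (the byte permutation itself is device-only):

```
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float values[] = {-128.0f, -5.0f, 0.0f, 3.0f, 127.0f};
  for (float v : values) {
    float tmp = v + 12582912.0f;
    int32_t bits;
    std::memcpy(&bits, &tmp, sizeof(bits));
    std::printf("%6.1f -> low byte 0x%02x (%d)\n",
                v, unsigned(bits & 0xff), int(int8_t(bits & 0xff)));
  }
  return 0;
}
```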
/// Partial specialization for Array<int8_t> <= Array<float>
template <int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<int8_t, float, N, Round> {
static_assert(!(N % 4), "N must be multiple of 4.");
using result_type = Array<int8_t, N>;
using source_type = Array<float, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
FastNumericArrayConverter<int8_t, float, 4, Round> convert_vector_;
result_type result;
Array<int8_t, 4> *result_ptr =
reinterpret_cast<Array<int8_t, 4> *>(&result);
Array<float, 4> const *source_ptr =
reinterpret_cast<Array<float, 4> const *>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines preferred rounding mode for a pair of types
template <typename T, typename S>
struct PreferredRoundingMode {
static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -28,7 +28,11 @@
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"

View File

@ -95,7 +95,11 @@
// Dependencies
//-----------------------------------------------------------------------------
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <stdint.h>
#endif
#if !defined(__CUDACC_RTC__)
//-----------------------------------------------------------------------------

View File

@ -28,10 +28,13 @@
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/cstdint>
#else
#include <assert.h>
#endif
#include <stdint.h>
#endif
#include "cutlass/cutlass.h"

View File

@ -30,6 +30,10 @@ namespace cutlass {
template <typename T>
struct RealType {
using Type = T;
static T from_real(double x) {
return static_cast<T>(x);
}
};
template <typename T>
@ -38,4 +42,5 @@ static T from_real(double r) {
return T(r);
}
} // namespace cutlass

View File

@ -85,7 +85,13 @@ struct ReduceAdd {
plus<FragmentAccumulator> op;
return op(accumulator, element);
NumericArrayConverter<
ElementAccumulator,
Element,
kCount,
PreferredRoundingMode<ElementAccumulator, Element>::kRound> converter;
return op(accumulator, converter(element));
}
};
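A minimal host-side sketch of the widening pattern this change introduces, assuming the accumulator element is float and the partials being reduced are half_t (names below are illustrative): the source fragment is converted to the accumulator type first, then added element-wise:

```
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"

cutlass::Array<float, 4> reduce_add_sketch(
    cutlass::Array<float, 4> const &accumulator,
    cutlass::Array<cutlass::half_t, 4> const &element) {

  cutlass::NumericArrayConverter<float, cutlass::half_t, 4> converter;
  cutlass::plus<cutlass::Array<float, 4>> op;

  return op(accumulator, converter(element));
}
```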

View File

@ -66,8 +66,9 @@ public:
/// Permit fetching the synchronization mechanism early
CUTLASS_DEVICE
void fetch() {
asm volatile ("ld.global.cg.s32 %0, [%1];\n" : "=r"(state) : "l"(lock));
if (wait_thread) {
asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
}
}
/// Gets the internal state
@ -80,14 +81,8 @@ public:
CUTLASS_DEVICE
void wait(int status = 0) {
if (wait_thread) {
while (state != status) {
fetch();
__syncwarp(0x01);
};
while( __syncthreads_and(state != status) ) {
fetch();
}
__syncthreads();
@ -99,8 +94,7 @@ public:
__syncthreads();
if (wait_thread) {
asm volatile ("st.global.cg.s32 [%0], %1;\n" : : "l"(lock), "r"(status));
asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
}
}
};
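The rewritten wait loop makes every thread participate in a __syncthreads_and-based barrier while only the designated thread polls the lock with ld.global.cg, rather than spinning a single thread under __syncwarp. A hedged usage sketch of how such a semaphore typically serializes split-K partial reductions; parameter names and the cutlass/semaphore.h path are assumptions, not shown in this diff:

```
#include "cutlass/cutlass.h"
#include "cutlass/semaphore.h"

__global__ void split_k_serial_sketch(int *locks, int tile_idx,
                                      int split_k_idx, int split_k_slices) {
  cutlass::Semaphore semaphore(locks + tile_idx, threadIdx.x);

  // Begin polling early so the load overlaps other work.
  semaphore.fetch();

  // ... compute this split-K slice's partial accumulators ...

  // Block until the previous slice has published its partials.
  semaphore.wait(split_k_idx);

  // ... reduce partials into the output tile ...

  // Publish; the last slice wraps back to zero for the next launch.
  semaphore.release((split_k_idx + 1) % split_k_slices);
}
```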

View File

@ -379,6 +379,7 @@ template <
typename Element,
typename Layout
>
CUTLASS_HOST_DEVICE
bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment) {
int const kStrideRank = Layout::kStrideRank;

View File

@ -0,0 +1,368 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a structure containing strides, bounds, and a pointer to tensor data.
*/
#pragma once
#include <cstdint>
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/tensor_ref.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element_>
struct PlanarComplexReference {
//
// Type definitions
//
using Element = Element_;
using ComplexElement = complex<Element>;
//
// Data members
//
Element *real;
Element *imag;
//
// Methods
//
CUTLASS_HOST_DEVICE
PlanarComplexReference(
Element *real_ = nullptr,
Element *imag_ = nullptr
):
real(real_), imag(imag_) { }
/// Loads the complex element
CUTLASS_HOST_DEVICE
operator complex<Element>() const {
return complex<Element>{*real, *imag};
}
/// Stores a complex element to the location pointed to by the reference
CUTLASS_HOST_DEVICE
PlanarComplexReference &operator=(complex<Element> const &rhs) {
*real = rhs.real();
*imag = rhs.imag();
return *this;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
and layout within memory. A TensorRef combines a pointer and a Layout concept
*/
template <
/// Data type of element stored within tensor (concept: NumericType)
typename Element_,
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
typename Layout_
>
class TensorRefPlanarComplex {
public:
/// Data type of individual access
using Element = Element_;
/// Complex element type
using ComplexElement = complex<Element>;
/// Mapping function from logical coordinate to linear memory
using Layout = Layout_;
static_assert(sizeof_bits<Element>::value >= 8,
"Planar complex not suitable for subbyte elements at this time");
/// Reference type to an element
using Reference = PlanarComplexReference<Element>;
/// Logical rank of tensor index space
static int const kRank = Layout::kRank;
/// Index type
using Index = typename Layout::Index;
/// Long index used for pointer offsets
using LongIndex = typename Layout::LongIndex;
/// Coordinate in logical tensor space
using TensorCoord = typename Layout::TensorCoord;
/// Layout's stride vector
using Stride = typename Layout::Stride;
/// TensorRef to constant data
using ConstTensorRef = TensorRefPlanarComplex<
typename platform::remove_const<Element>::type const,
Layout>;
/// TensorRef to non-constant data
using NonConstTensorRef = TensorRefPlanarComplex<
typename platform::remove_const<Element>::type,
Layout>;
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
/// scalar, but degenerate cases such as these are difficult to accommodate without
/// extensive C++ metaprogramming or support for zero-length arrays.
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
private:
/// Pointer
Element* ptr_;
/// Layout object maps logical coordinates to linear offsets
Layout layout_;
/// Offset to imaginary part
LongIndex imaginary_stride_;
public:
//
// Methods
//
/// Constructs a TensorRef with a pointer and layout object.
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex(
Element *ptr = nullptr, ///< pointer to start of tensor
Layout const &layout = Layout(), ///< layout object containing stride and mapping function
LongIndex imaginary_stride = 0
):
ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) {
}
/// Converting constructor from TensorRef to non-constant data.
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex(
NonConstTensorRef const &ref ///< TensorRef to non-const data
):
ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride()) { }
/// Returns a reference to constant-valued tensor.
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const {
return ConstTensorRef(ptr_, layout_, imaginary_stride_);
}
CUTLASS_HOST_DEVICE
NonConstTensorRef non_const_ref() const {
return NonConstTensorRef(
const_cast<typename platform::remove_const<Element>::type *>(ptr_),
layout_,
imaginary_stride_);
}
/// Updates only the pointer
CUTLASS_HOST_DEVICE
void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) {
ptr_ = ptr;
imaginary_stride_ = imaginary_stride;
}
/// Updates the pointer and layout object
CUTLASS_HOST_DEVICE
void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) {
ptr_ = ptr;
layout_ = layout;
imaginary_stride_ = imaginary_stride;
}
/// Returns true if the TensorRef is non-null
CUTLASS_HOST_DEVICE
bool good() const {
return ptr_ != nullptr;
}
/// Returns the pointer to referenced data
CUTLASS_HOST_DEVICE
Element * data() const { return ptr_; }
/// Returns the pointer to referenced data
CUTLASS_HOST_DEVICE
Element * imaginary_data() const { return ptr_ + imaginary_stride_; }
/// Returns a reference to the element at a given linear index
CUTLASS_HOST_DEVICE
Reference data(LongIndex idx) const {
return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_);
}
/// Returns the layout object
CUTLASS_HOST_DEVICE
Layout & layout() {
return layout_;
}
/// Returns the layout object
CUTLASS_HOST_DEVICE
Layout layout() const {
return layout_;
}
/// Gets the stride to an imaginary element
CUTLASS_HOST_DEVICE
LongIndex imaginary_stride() const {
return imaginary_stride_;
}
/// Gets the stride to an imaginary element
CUTLASS_HOST_DEVICE
LongIndex &imaginary_stride() {
return imaginary_stride_;
}
/// Returns the layout object's stride vector
CUTLASS_HOST_DEVICE
Stride stride() const {
return layout_.stride();
}
/// Returns the layout object's stride vector
CUTLASS_HOST_DEVICE
Stride & stride() {
return layout_.stride();
}
/// Returns the layout object's stride in a given physical dimension
CUTLASS_HOST_DEVICE
Index stride(int dim) const {
return layout_.stride().at(dim);
}
/// Returns the layout object's stride in a given physical dimension
CUTLASS_HOST_DEVICE
Index & stride(int dim) {
return layout_.stride().at(dim);
}
/// Computes the offset of an index from the origin of the tensor
CUTLASS_HOST_DEVICE
LongIndex offset(TensorCoord const& coord) const {
return layout_(coord);
}
/// Returns a reference to the element at a given Coord
CUTLASS_HOST_DEVICE
Reference at(TensorCoord const& coord) const {
return data(offset(coord));
}
/// Returns a reference to the element at a given Coord
CUTLASS_HOST_DEVICE
Reference operator[](TensorCoord const& coord) const {
return data(offset(coord));
}
/// Adds an offset to each pointer
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) {
ptr_ += offset_;
return *this;
}
/// Adds an offset to each pointer
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) {
add_pointer_offset(offset(coord));
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex operator+(TensorCoord const& b) const {
TensorRefPlanarComplex result(*this);
result.add_coord_offset(b);
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex & operator+=(TensorCoord const& b) {
add_coord_offset(b);
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex operator-(TensorCoord const& b) const {
TensorRefPlanarComplex result(*this);
result.add_pointer_offset(-offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex & operator-=(TensorCoord const& b) {
add_pointer_offset(-offset(b));
return *this;
}
/// TensorRef to real-valued tensor
CUTLASS_HOST_DEVICE
cutlass::TensorRef<Element, Layout> ref_real() const {
return cutlass::TensorRef<Element, Layout>(data(), layout());
}
/// TensorRef to real-valued tensor
CUTLASS_HOST_DEVICE
cutlass::TensorRef<Element, Layout> ref_imag() const {
return cutlass::TensorRef<Element, Layout>(imaginary_data(), layout());
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs a TensorRef, deducing types from arguments.
template <
typename Element,
typename Layout
>
CUTLASS_HOST_DEVICE
TensorRefPlanarComplex<Element, Layout> make_TensorRefPlanarComplex(
Element *ptr,
Layout const &layout,
int64_t imaginary_stride) {
return TensorRefPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
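A minimal host-side usage sketch of the planar-complex convention this reference encodes: one allocation holds the real plane followed by the imaginary plane, and imaginary_stride is the element distance between them. The column-major layout and sizes are illustrative:

```
#include <vector>
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/tensor_ref_planar_complex.h"

void planar_complex_ref_sketch(int M, int N) {
  int64_t imaginary_stride = int64_t(M) * N;
  std::vector<float> data(2 * imaginary_stride);    // [ real plane | imaginary plane ]

  auto ref = cutlass::make_TensorRefPlanarComplex(
      data.data(), cutlass::layout::ColumnMajor(M), imaginary_stride);

  // Writing element (2, 3) stores to both planes through a PlanarComplexReference.
  ref.at(cutlass::MatrixCoord(2, 3)) = cutlass::complex<float>(1.0f, -1.0f);
}
```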

View File

@ -0,0 +1,293 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a structure containing strides and a pointer to tensor data.
TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
data storage and is therefore lightweight and may be embedded in larger tensor objects or
memory structures.
See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
linear memory.
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cmath>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref_planar_complex.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Data type of element stored within tensor
typename Element_,
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
typename Layout_
>
class TensorViewPlanarComplex : public TensorRefPlanarComplex<Element_, Layout_> {
public:
/// Base tensor reference
using Base = cutlass::TensorRefPlanarComplex<Element_, Layout_>;
/// Mapping function from logical coordinate to internal n-D array
using Layout = Layout_;
/// TensorRef pointing to constant memory
using ConstTensorRef = typename Base::ConstTensorRef;
/// Underlying TensorRef type
using TensorRef = Base;
/// Data type of individual access
using Element = Element_;
/// Reference type to an element
using Reference = Element &;
/// Logical rank of tensor index space
static int const kRank = Layout::kRank;
/// Index type
using Index = typename Layout::Index;
/// Long index used for pointer offsets
using LongIndex = typename Layout::LongIndex;
/// Coordinate in logical tensor space
using TensorCoord = typename Layout::TensorCoord;
/// Coordinate in storage n-D array
using Stride = typename Layout::Stride;
/// TensorView pointing to constant memory
using ConstTensorView = TensorViewPlanarComplex<
typename platform::remove_const<Element>::type const,
Layout>;
/// TensorView pointing to non-constant memory
using NonConstTensorView = TensorViewPlanarComplex<
typename platform::remove_const<Element>::type,
Layout>;
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
/// scalar, but degenerate cases such as these are difficult to accommodate without
/// extensive C++ metaprogramming or support for zero-length arrays.
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
private:
/// View extent
TensorCoord extent_;
public:
//
// Methods
//
/// Constructs a TensorView object
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) {
}
/// Constructs a TensorView object
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex(
Element *ptr, ///< pointer to start of tensor
Layout const &layout, ///< layout object containing stride and mapping function
LongIndex imaginary_stride, ///< stride between real and imaginary part
TensorCoord const &extent ///< size of the view in logical coordinates
):
Base(ptr, layout, imaginary_stride), extent_(extent) {
}
/// Constructs a TensorView object
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex(
TensorRef const &ref, ///< pointer and layout object referencing a tensor
TensorCoord const &extent ///< logical size of tensor
):
Base(ref), extent_(extent) {
}
/// Converting constructor from TensorRef to non-constant data.
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex(
NonConstTensorView const &view ///< TensorView to non-const data
):
Base(view), extent_(view.extent_) { }
/// Updates the pointer and layout object
CUTLASS_HOST_DEVICE
void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) {
Base::reset(ptr, layout, imaginary_stride);
this->resize(size);
}
/// Changes the size of the view without affecting pointer or layout
CUTLASS_HOST_DEVICE
void resize(TensorCoord extent) {
this->extent_ = extent;
}
/// Returns the extent of the view (the size along each logical dimension).
CUTLASS_HOST_DEVICE
TensorCoord const& extent() const { return extent_; }
/// Returns the extent along a particular logical dimension.
CUTLASS_HOST_DEVICE
Index extent(int dim) const { return extent_.at(dim); }
/// Determines whether a location is within a tensor
CUTLASS_HOST_DEVICE
bool contains(TensorCoord const& coord) const {
CUTLASS_PRAGMA_UNROLL
for (int dim = 0; dim < kRank; ++dim) {
if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
return false;
}
}
return true;
}
/// Returns a TensorRef pointing to the first element of the tensor.
CUTLASS_HOST_DEVICE
Base ref() const {
return Base(this->data(), this->layout(), this->imaginary_stride());
}
/// Returns a TensorRef pointing to the first element of the tensor.
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const {
return ConstTensorRef(this->data(), this->layout(), this->imaginary_stride());
}
/// Returns a TensorView to const data
CUTLASS_HOST_DEVICE
ConstTensorView const_view() const {
return ConstTensorView(const_ref(), extent_);
}
/// Returns a TensorView given location and size quantities
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex subview(
TensorCoord extent, ///< extent of the resulting view
TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view
) const {
return TensorViewPlanarComplex(ref(), extent.clamp(extent_ - location)).add_coord_offset(location);
}
/// Returns the number of scalar elements needed to store tensor.
CUTLASS_HOST_DEVICE
size_t capacity() const {
return Base::layout().capacity(extent_);
}
/// Returns a TensorView offset by a given amount
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex operator+(
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
) const {
TensorViewPlanarComplex result(*this);
result.add_pointer_offset(this->offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex& operator+=(
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
) {
this->add_pointer_offset(this->offset(b));
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex operator-(
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
) const {
TensorViewPlanarComplex result(*this);
result.add_pointer_offset(-this->offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorViewPlanarComplex& operator-=(
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
) {
this->add_pointer_offset(-this->offset(b));
return *this;
}
/// TensorRef to real-valued tensor
CUTLASS_HOST_DEVICE
cutlass::TensorView<Element, Layout> view_real() const {
return cutlass::TensorView<Element, Layout>(this->data(), this->layout(), extent_);
}
/// TensorRef to real-valued tensor
CUTLASS_HOST_DEVICE
cutlass::TensorView<Element, Layout> view_imag() const {
return cutlass::TensorView<Element, Layout>(this->imaginary_data(), this->layout(), extent_);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs a TensorRef, deducing types from arguments.
template <
typename Element,
typename Layout
>
CUTLASS_HOST_DEVICE TensorViewPlanarComplex<Element, Layout> make_TensorViewPlanarComplex(
Element *ptr,
Layout const &layout,
typename Layout::LongIndex imaginary_stride,
typename Layout::TensorCoord const &extent) {
return TensorViewPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride, extent);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -128,13 +128,13 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
int stride_;
/// amount (in bytes) to increment pointer to move to the next access along
/// strided dimension
int inc_strided_;
LongIndex inc_strided_;
/// amount (in bytes) to increment pointer from the last access to the first access
/// of the next tile
int inc_next_;
LongIndex inc_next_;
/// amount (in bytes) to increment pointer from the first access of the current tile
/// to the first access of the next tile
int inc_advance_;
LongIndex inc_advance_;
public:
@ -145,20 +145,20 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const &layout) : stride_(layout.stride(0)) {
inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
sizeof_bits<Element>::value / 8;
if (kAdvanceRank) {
// advance along strided dimension
inc_advance_ =
Shape::kStrided * stride_ * sizeof_bits<Element>::value / 8;
Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
} else {
// advance along contiguous dimension
inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
}
inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
ThreadMap::Delta::kStrided * stride_ *
inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
ThreadMap::Delta::kStrided * LongIndex(stride_) *
sizeof_bits<Element>::value / 8;
};
};
@ -280,7 +280,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
TensorCoord residue_extent;
if (kAdvanceRank) {
Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
if (!residue_size) {
residue_size = Shape::kStrided;
}
@ -288,18 +288,19 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
residue_offset_ = make_Coord(0, residue_size);
residue_extent = make_Coord(
extent_.contiguous(),
min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
min(threadblock_offset.strided() + residue_size, extent_.strided())
);
} else {
Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
if (!residue_size) {
residue_size = Shape::kContiguous;
}
residue_offset_ = make_Coord(residue_size, 0);
residue_extent = make_Coord(
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
extent_.strided()
);
}
@ -362,18 +363,18 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
compute_predicates_(extent_, true);
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
pointer_ += Shape::kContiguous * tile_offset.contiguous();
} else {
pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
pointer_ += Shape::kStrided * tile_offset.strided();
}
} else {
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * tile_offset.strided();
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
pointer_ += Shape::kContiguous * tile_offset.contiguous();
} else {
pointer_ += params_.inc_advance_ * tile_offset.contiguous();
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
pointer_ += Shape::kStrided * tile_offset.strided();
}
}

View File

@ -296,7 +296,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
@ -310,10 +315,12 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
address_iterator_.set_iteration_index(idx);
auto ptr = (address_iterator_.get() + pointer_offset);
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
if (address_iterator_.valid()) {
frag_ptr[idx] = *ptr;
frag_ptr[idx] = *access_ptr;
}
++address_iterator_;
}
@ -323,11 +330,17 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
@ -340,8 +353,11 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
if (address_iterator_.valid()) {
*(address_iterator_.get() + pointer_offset) = frag_ptr[idx];
*access_ptr = frag_ptr[idx];
}
++address_iterator_;
}
@ -351,7 +367,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
@ -528,6 +544,12 @@ public:
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
@ -540,6 +562,12 @@ public:
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) {
@ -721,6 +749,12 @@ public:
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) {
@ -732,6 +766,12 @@ public:
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE

Some files were not shown because too many files have changed in this diff.