diff --git a/CHANGELOG.md b/CHANGELOG.md index d7ce23a3..367d6935 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ # NVIDIA CUTLASS Changelog -# CUTLASS 2.0 +# CUTLASS 2.x + +## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06) + * BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library) + * API to launch compiled kernel instances for GEMM and planar complex GEMM + * Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores + * Computes complex matrix products on matrices stored as disjoint real and imaginary parts + * [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu) + * Minor enhancements and bug fixes ## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19) * Substantially refactored for @@ -22,7 +30,7 @@ * Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions * Batched GEMM operations * Complex-valued GEMMs - * Note: a host compiler supporting C++11 or greater is required. + * **Note: a host compiler supporting C++11 or greater is required.** # CUTLASS 1.x @@ -76,7 +84,7 @@ ## Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. 
``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/CMakeLists.txt b/CMakeLists.txt old mode 100644 new mode 100755 index 85d39593..1b7bbc48 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,15 +32,14 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -project(CUTLASS VERSION 2.0.0 LANGUAGES CXX) +project(CUTLASS VERSION 2.1.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) find_package(Doxygen QUIET) # -# CUTLASS 2.0 requires C++11 +# CUTLASS 2.x requires C++11 # - set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) @@ -49,7 +48,7 @@ if(CUTLASS_NATIVE_CUDA) set(CMAKE_CUDA_STANDARD 11) set(CMAKE_CUDA_STANDARD_REQUIRED ON) else() - string(APPEND NVCC_FLAGS " --std=c++11") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11) endif() if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) @@ -58,13 +57,26 @@ endif() message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}") -if(${CMAKE_PROJECT_NAME} MATCHES ${PROJECT_NAME}) - set(_CUTLASS_ENABLE_TESTS ON) +set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") + +if(CUTLASS_ENABLE_HEADERS_ONLY) + set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) + set(CUTLASS_ENABLE_TOOLS_INIT OFF) else() - set(_CUTLASS_ENABLE_TESTS OFF) + set(CUTLASS_ENABLE_EXAMPLES_INIT ON) + set(CUTLASS_ENABLE_TOOLS_INIT ON) endif() -set(CUTLASS_ENABLE_TESTS ${_CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS Tests") +set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples") +set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") + +if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) + set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT}) +else() + set(CUTLASS_ENABLE_TESTS_INIT OFF) +endif() + +set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") if (CUTLASS_ENABLE_TESTS) 
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake) @@ -86,31 +98,22 @@ endif() if (NOT CUDA_VERSION VERSION_LESS 10.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75) endif() - -if(CUDA_COMPILER MATCHES "[Cc]lang") - if(NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) - message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) - endif() - if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) - message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") - endif() -endif() - set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") # Special policy introduced in CMake 3.13 if (POLICY CMP0076) cmake_policy(SET CMP0076 NEW) -endif() +endif() -# check if the configuration is supported -if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) +if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 ) message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!") endif() include(GNUInstallDirs) +link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) + ################################################################################################### # # Configure CMake variables @@ -120,11 +123,14 @@ include(GNUInstallDirs) message(STATUS "CUDA Compilation Architectures: ${CUTLASS_NVCC_ARCHS_ENABLED}") if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES)) - # By default we want to build in Release mode to ensure that we're getting best performance. + # By default we want to build in Release mode to ensure that we're getting best performance. 
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release") endif() +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries") + if(WIN32) # On Windows we link against the shared (DLL) runtime. Change gtest settings to match this. set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE) @@ -132,29 +138,35 @@ endif() if (WIN32) # Enable more warnings and treat as errors - string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX) # Disable warning on Unicode characters - string(APPEND NVCC_FLAGS " -Xcompiler /wd4819") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819) # Disable excess x86 floating point precision that can lead to results being labeled incorrectly - string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/fp:strict) endif(WIN32) if (${CUTLASS_NVCC_VERBOSE}) - string(APPEND NVCC_FLAGS " -v") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v) endif() set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.") set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.") -set(CUTLASS_LIBRARY_KERNELS "128x128" CACHE STRING "Comma delimited list of kernel name filters. Default '' means all kernels are enabled.") + +# +# CUTLASS generator cmake configuration +# +set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. 
Default '' means all operations are enabled.") +set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.") + # Test Levels L0, L1, L2 set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2) -string(APPEND NVCC_FLAGS " -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}") +list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}) # # CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations. @@ -166,7 +178,7 @@ else() set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON) endif() -set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL +set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL "Enable PTX mma instruction for collective matrix multiply operations.") # @@ -182,7 +194,7 @@ set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CAC # ... 
# if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer - string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --compiler-options=-fsanitize=address --compiler-options=-fno-omit-frame-pointer) string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address") endif() @@ -192,85 +204,127 @@ endif() # ################################################################################################### -foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED}) - if(CUTLASS_NVCC_EMBED_CUBIN) - string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=sm_${ARCH}") - endif() - if(CUTLASS_NVCC_EMBED_PTX) - string(APPEND NVCC_GENCODE_FLAGS " -gencode=arch=compute_${ARCH},code=compute_${ARCH}") - endif() - string(APPEND CLANG_FLAGS " --cuda-gpu-arch=sm_${ARCH}") -endforeach() - if(CUTLASS_NVCC_EMBED_PTX) - string(APPEND CLANG_FLAGS " --cuda-include-ptx=all") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all) endif() if (CUTLASS_ENABLE_TENSOR_CORE_MMA) - string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1") + list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR}) - string(APPEND NVCC_FLAGS " --keep") # --keep-dir may not work with nvcc for some directories. - string(APPEND CLANG_FLAGS " -save-temps=${CUTLASS_NVCC_KEEP_DIR}") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories. 
+ list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR}) endif() -if (CUTLASS_ENABLE_F16C) - string(APPEND COMPILER_FLAGS " -DCUTLASS_ENABLE_F16C=1") +if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING) + list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1) if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) - string(APPEND NVCC_FLAGS " -Xcompiler -mf16c") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c) elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC")) - string(APPEND NVCC_FLAGS " -Xcompiler /arch:AVX2") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2) endif() endif() -string(APPEND NVCC_FLAGS " -lineinfo") +list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-Wconversion>) +list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-fno-strict-aliasing>) -string(APPEND CLANG_FLAGS " -gmlt") - -if (UNIX) - string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion") - string(APPEND NVCC_FLAGS " -Xcompiler -fno-strict-aliasing") +# Don't leak lineinfo in release builds +if (NOT CMAKE_BUILD_TYPE MATCHES "Release") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo) endif() if(CUDA_COMPILER MATCHES "[Cc]lang") - string(APPEND CLANG_FLAGS " --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}") - string(APPEND CLANG_FLAGS " -mllvm -pragma-unroll-threshold=100000") - string(APPEND CLANG_FLAGS " -mllvm -unroll-threshold=5000") - string(APPEND CLANG_FLAGS " -Wno-unused-command-line-argument") + if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) + message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. 
Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) + endif() + + if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) + message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") + endif() + + list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument) string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION}) list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR) list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR) - string(APPEND CLANG_FLAGS " -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}) + # needed for libcublasLt.so in case it's installed in the same location as libcudart.so # dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags) # Otherwise linker uses RUNPATH and that does not propagate to loaded libs. 
- string(APPEND CLANG_FLAGS " -Wl,--disable-new-dtags") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags) link_libraries(nvidia::cudart) endif() -if(CUDA_COMPILER MATCHES "[Cc]lang") - string(APPEND CMAKE_CXX_FLAGS "${COMMON_FLAGS} ${CLANG_FLAGS}") - string(APPEND CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${CLANG_FLAGS_RELEASE}") - string(APPEND CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${CLANG_FLAGS_RELWITHDEBINFO}") - string(APPEND CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${CLANG_FLAGS_DEBUG}") -elseif (CUTLASS_NATIVE_CUDA) - string(APPEND CMAKE_CUDA_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}") - string(APPEND CMAKE_CUDA_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}") - string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}") - string(APPEND CMAKE_CUDA_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}") -else() - string(APPEND CUDA_NVCC_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS} ${NVCC_GENCODE_FLAGS}") - string(APPEND CUDA_NVCC_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE} ${NVCC_FLAGS_RELEASE}") - string(APPEND CUDA_NVCC_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO} ${NVCC_FLAGS_RELWITHDEBINFO}") - string(APPEND CUDA_NVCC_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG} ${NVCC_FLAGS_DEBUG}") -endif() +function(cutlass_apply_cuda_gencode_flags TARGET) + + set(NVCC_FLAGS) + set(CLANG_FLAGS) + foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED}) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH}) + set(CODES) + if(CUTLASS_NVCC_EMBED_CUBIN) + list(APPEND CODES sm_${ARCH}) + endif() + if(CUTLASS_NVCC_EMBED_PTX) + list(APPEND CODES compute_${ARCH}) + endif() + list(JOIN CODES "," CODES_STR) + list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) + endforeach() + + if (CUDA_COMPILER MATCHES "[Cc]lang") + target_compile_options( + ${TARGET} + PRIVATE + $<$:${CLANG_FLAGS}> + ) + else() + target_compile_options( + ${TARGET} + PRIVATE 
+ $<$:${NVCC_FLAGS}> + ) + endif() + +endfunction() + +function(cutlass_apply_standard_compile_options TARGET) + + if(CUDA_COMPILER MATCHES "[Cc]lang") + set(CUDA_COMPILE_LANGUAGE CXX) + set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS}) + set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE}) + set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO}) + set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG}) + else() + set(CUDA_COMPILE_LANGUAGE CUDA) + set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS}) + set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE}) + set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO}) + set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG}) + endif() + + target_compile_options( + ${TARGET} + PRIVATE + $<$:${_FLAGS}> + $<$:$<$:${_FLAGS_RELEASE}>> + $<$:$<$:${_FLAGS_RELWITHDEBINFO}>> + $<$:$<$:${_FLAGS_DEBUG}>> + ) + +endfunction() # # The following items should eventually be pushed into cutlass/CMakeLists.txt @@ -324,8 +378,8 @@ if (NOT DEFINED CUTLASS_REVISION) endif() configure_file( - ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in - ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in + ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h @ONLY) target_include_directories( @@ -338,8 +392,8 @@ target_include_directories( ) install( - DIRECTORY - ${CUTLASS_INCLUDE_DIR}/ + DIRECTORY + ${CUTLASS_INCLUDE_DIR}/ ${CMAKE_CURRENT_BINARY_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) @@ -399,27 +453,6 @@ endif() ################################################################################ -set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") - -if(CUTLASS_ENABLE_HEADERS_ONLY) - set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) - 
set(CUTLASS_ENABLE_TOOLS_INIT OFF) -else() - set(CUTLASS_ENABLE_EXAMPLES_INIT ON) - set(CUTLASS_ENABLE_TOOLS_INIT ON) -endif() - -set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples") -set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") - -if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) - set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT}) -else() - set(CUTLASS_ENABLE_TESTS_INIT OFF) -endif() - -set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") - if(CUTLASS_ENABLE_TOOLS) add_subdirectory(tools) endif() diff --git a/CUDA.cmake b/CUDA.cmake index 6978a518..d1eb4dbc 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -39,23 +39,27 @@ if(CUTLASS_NATIVE_CUDA) enable_language(CUDA) + if(NOT CUDA_VERSION) + set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) + endif() + if(NOT CUDA_TOOLKIT_ROOT_DIR) + get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) + endif() + else() find_package(CUDA REQUIRED) + # We workaround missing variables with the native flow by also finding the CUDA toolkit the old way. -endif() + if(NOT CMAKE_CUDA_COMPILER_VERSION) + set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION}) + endif() -if(NOT CUDA_VERSION) - set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) -endif() -if(NOT CUDA_TOOLKIT_ROOT_DIR) - get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) endif() if (CUDA_VERSION VERSION_LESS 9.2) message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.") endif() - if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang") set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc) message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") @@ -74,7 +78,7 @@ find_library( # in the CUDA toolkit we're building against. 
) -if(CUDART_LIBRARY) +if(NOT TARGET cudart AND CUDART_LIBRARY) message(STATUS "CUDART: ${CUDART_LIBRARY}") @@ -95,6 +99,10 @@ if(CUDART_LIBRARY) ${CUDART_LIBRARY} ) +elseif(TARGET cudart) + + message(STATUS "CUDART: Already Found") + else() message(STATUS "CUDART: Not Found") @@ -116,7 +124,7 @@ find_library( # in the CUDA toolkit we're building against. ) -if(CUDA_DRIVER_LIBRARY) +if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY) message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}") @@ -137,6 +145,10 @@ if(CUDA_DRIVER_LIBRARY) ${CUDA_DRIVER_LIBRARY} ) +elseif(TARGET cuda_driver) + + message(STATUS "CUDA Driver: Already Found") + else() message(STATUS "CUDA Driver: Not Found") @@ -156,7 +168,7 @@ find_library( # in the CUDA toolkit we're building against. ) -if(NVRTC_LIBRARY) +if(NOT TARGET nvrtc AND NVRTC_LIBRARY) message(STATUS "NVRTC: ${NVRTC_LIBRARY}") @@ -177,6 +189,10 @@ if(NVRTC_LIBRARY) ${NVRTC_LIBRARY} ) +elseif(TARGET nvrtc) + + message(STATUS "NVRTC: Already Found") + else() message(STATUS "NVRTC: Not Found") @@ -197,48 +213,137 @@ function(cutlass_correct_source_file_language_property) endif() endfunction() -function(cutlass_add_library) +set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation") +set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files") - set(options INTERFACE STATIC SHARED OBJECT) - set(oneValueArgs) +function(cutlass_unify_source_files TARGET_ARGS_VAR) + + set(options) + set(oneValueArgs BATCH_SOURCES BATCH_SIZE) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang" OR __INTERFACE) - cutlass_correct_source_file_language_property(${ARGN}) - add_library(${ARGN}) - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${ARGN}) + if (NOT DEFINED TARGET_ARGS_VAR) + message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required") endif() + if (__BATCH_SOURCES AND 
NOT DEFINED __BATCH_SIZE) + set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE}) + endif() + + if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1) + + set(CUDA_FILE_ARGS) + set(TARGET_SOURCE_ARGS) + + foreach(ARG ${__UNPARSED_ARGUMENTS}) + if(${ARG} MATCHES ".*\.cu$") + list(APPEND CUDA_FILE_ARGS ${ARG}) + else() + list(APPEND TARGET_SOURCE_ARGS ${ARG}) + endif() + endforeach() + + list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS) + while(NUM_CUDA_FILE_ARGS GREATER 0) + list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH) + string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}") + string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH) + set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu) + message(STATUS "Generating ${BATCH_FILE}") + file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n") + foreach(CUDA_FILE ${CUDA_FILE_BATCH}) + get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE) + file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n") + endforeach() + list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE}) + if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE) + break() + endif() + list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS) + list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS) + endwhile() + + else() + + set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + + endif() + + set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE) + endfunction() -function(cutlass_add_executable) +function(cutlass_add_library NAME) set(options) - set(oneValueArgs) + set(oneValueArgs EXPORT_NAME) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${ARGN}) - add_executable(${ARGN}) + 
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) + add_library(${NAME} ${TARGET_SOURCE_ARGS}) else() set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_executable(${ARGN}) + cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS}) + endif() + + cutlass_apply_standard_compile_options(${NAME}) + cutlass_apply_cuda_gencode_flags(${NAME}) + + target_compile_features( + ${NAME} + INTERFACE + cxx_std_11 + ) + + if(__EXPORT_NAME) + add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME}) + set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME}) endif() endfunction() -function(cutlass_target_sources) +function(cutlass_add_executable NAME) set(options) set(oneValueArgs) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cutlass_correct_source_file_language_property(${ARGN}) - target_sources(${ARGN}) + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + + if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") + cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) + add_executable(${NAME} ${TARGET_SOURCE_ARGS}) + else() + set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) + cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS}) + endif() + + cutlass_apply_standard_compile_options(${NAME}) + cutlass_apply_cuda_gencode_flags(${NAME}) + + target_compile_features( + ${NAME} + INTERFACE + cxx_std_11 + ) + +endfunction() + +function(cutlass_target_sources NAME) + + set(options) + set(oneValueArgs) + set(multiValueArgs) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) + target_sources(${NAME} ${TARGET_SOURCE_ARGS}) endfunction() diff --git a/README.md b/README.md index 3b5f4728..dd1c4c65 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ 
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.0 +# CUTLASS 2.1 -_CUTLASS 2.0 - November 2019_ +_CUTLASS 2.1 - April 2020_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. @@ -25,6 +25,14 @@ NVIDIA's Volta and Turing architectures. See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. +# What's New in CUTLASS 2.1 + +CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding: + +- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores +- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library) + + # What's New in CUTLASS 2.0 CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer: @@ -33,7 +41,7 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to - Robust and durable templates that reliably span the design space - Encapsulated functionality that may be reusable in other contexts -See the [CHANGELOG](CHANGELOG.md) for more details. +**See the [CHANGELOG](CHANGELOG.md) for more details.** See the [functionality listing](media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. @@ -82,7 +90,7 @@ any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU. # Documentation -CUTLASS 2.0 is described in the following documents and the accompanying +CUTLASS 2.1 is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). - [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS @@ -177,7 +185,7 @@ include/ # client applications should target this directory ### CUTLASS SDK Examples -CUTLASS SDK examples apply CUTLASS templates to implement basic computations. 
+[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations. ``` examples/ @@ -198,12 +206,19 @@ examples/ 07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores 08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores + + 10_planar_complex/ # example demonstrating planar complex GEMM kernels + + 11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes ``` ### Tools ``` tools/ library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates + include/ + cutlass/ + library/ profiler/ # CUTLASS Profiler - command-line utility for executing operations in the # CUTLASS Library @@ -279,7 +294,7 @@ The official list of CUTLASS developers and contributors is available here: [CON # Copyright -Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. ``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/cuBLAS.cmake b/cuBLAS.cmake index 60a56ca5..d7f330cf 100644 --- a/cuBLAS.cmake +++ b/cuBLAS.cmake @@ -1,7 +1,8 @@ message(STATUS "Configuring cublas ...") -if(DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) +if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR + (DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED)) # Don't add cuBLAS if it's defined and false, assume it's not found. 
@@ -59,11 +60,13 @@ endif() if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas) if(WIN32) - add_library(cublas STATIC IMPORTED) + add_library(cublas STATIC IMPORTED GLOBAL) else() - add_library(cublas SHARED IMPORTED) + add_library(cublas SHARED IMPORTED GLOBAL) endif() + add_library(nvidia::cublas ALIAS cublas) + set_property( TARGET cublas PROPERTY IMPORTED_LOCATION @@ -88,23 +91,20 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas) ${CUBLAS_PATH}/lib/x64 /usr/lib/x86_64-linux-gnu) - if(_CUBLASLT_LIBRARY) + if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt) if(WIN32) - add_library(cublasLt STATIC IMPORTED) + add_library(cublasLt STATIC IMPORTED GLOBAL) else() - add_library(cublasLt SHARED IMPORTED) + add_library(cublasLt SHARED IMPORTED GLOBAL) endif() set_property( TARGET cublasLt PROPERTY IMPORTED_LOCATION ${_CUBLASLT_LIBRARY}) - - target_link_libraries( - cublas - INTERFACE - cublasLt) + + add_library(nvidia::cublasLt ALIAS cublasLt) endif() diff --git a/examples/06_splitK_gemm/splitk_gemm.cu b/examples/06_splitK_gemm/splitk_gemm.cu index f0ce9825..5fb513cb 100644 --- a/examples/06_splitK_gemm/splitk_gemm.cu +++ b/examples/06_splitK_gemm/splitk_gemm.cu @@ -173,14 +173,35 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel; int main() { + + // + // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. + // + // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." 
<< std::endl; + return -1; + } + cudaDeviceProp props; - CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + if (!(props.major >= 7)) { std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70." << std::endl; - return 0; + return -1; } + // + // Define problem size + // + const int length_m = 5120; const int length_n = 4096; const int length_k = 4096; diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu index 424c90fc..447cc1cc 100644 --- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu +++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu @@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Volta GPU. Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose +really hard. CUTLASS solves this problem by providing simplified abstractions to compose multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU easily. @@ -189,13 +189,27 @@ using Gemm = cutlass::gemm::device::Gemm; int main() { + + // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. + // + // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." 
<< std::endl; + return -1; + } + cudaDeviceProp props; - CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } if (!(props.major >= 7)) { std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70." << std::endl; - return 0; + return -1; } const int length_m = 5120; diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index 1628ce0a..3440d82f 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -29,7 +29,7 @@ provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU. Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose +really hard. CUTLASS solves this problem by providing simplified abstractions to compose multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU easily. @@ -187,13 +187,28 @@ using Gemm = cutlass::gemm::device::Gemm; int main() { - cudaDeviceProp props; - CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major >= 7 && props.minor >= 5)) { - std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." + // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. 
+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + return -1; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 75)) { + std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75." << std::endl; - return 0; + return -1; } const int length_m = 5120; diff --git a/examples/10_planar_complex/CMakeLists.txt b/examples/10_planar_complex/CMakeLists.txt new file mode 100644 index 00000000..555836ae --- /dev/null +++ b/examples/10_planar_complex/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+# Planar Complex GEMM example
+cutlass_example_add_executable(
+  10_planar_complex
+  planar_complex.cu
+)
+
+
+#
+# This example depends on the CUTLASS Library
+#
+
+target_link_libraries(
+  10_planar_complex
+  PRIVATE
+  cutlass_lib
+  cutlass_tools_util_includes
+)
+
diff --git a/examples/10_planar_complex/planar_complex.cu b/examples/10_planar_complex/planar_complex.cu
new file mode 100644
index 00000000..7fc92870
--- /dev/null
+++ b/examples/10_planar_complex/planar_complex.cu
@@ -0,0 +1,551 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Planar Complex GEMM
+
+  This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels supporting
+  the batched strided mode.
+
+  These kernels represent complex matrices by storing the real and imaginary parts of the matrix in
+  disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts
+  as either column-major or row-major layouts with a single leading dimension indicating the stride
+  between columns or rows.
+
+  The CUTLASS Library collects multiple template instantiations in a data structure and offers
+  a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures.
+
+  CUTLASS decouples matrix layout from complex transformation, so four possible transformations
+  are possible on the A and B operands:
+
+    n: column-major
+    c: column-major complex conjugate
+    t: row-major
+    h: row-major complex conjugate
+
+  The CUTLASS Library contains many kernel instances specialized for architecture, data type, tile
+  size, and alignment. This can result in long compile times.
+
+  To build strictly the planar complex kernels needed for general application, execute the following
+  CMake command in an empty build directory.
+
+    $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
+      -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
+
+  This builds all planar complex GEMM variants for Volta and Turing architectures.
+
+  To build strictly the kernels needed for this example, an even narrower filter string may be
+  specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
+  the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
+
+    $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
+      -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn
+
+    $ make 10_planar_complex
+
+    $ ./examples/10_planar_complex/10_planar_complex --m=2048 --n=1024 --k=512 --batch=10
+*/
+
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/host_tensor_planar_complex.h"
+
+#include "cutlass/util/reference/device/tensor_fill.h"
+
+#include "cutlass/util/reference/device/gemm_planar_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+
+#include "cutlass/library/handle.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Result structure
+struct Result {
+
+  double runtime_ms;
+  double gflops;
+  cutlass::Status status;
+  cudaError_t error;
+  bool passed;
+
+  //
+  // Methods
+  //
+
+  Result(
+    double runtime_ms = 0,
+    double gflops = 0,
+    cutlass::Status status = cutlass::Status::kSuccess,
+    cudaError_t error = cudaSuccess
+  ):
+    runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Command line options parsing
+struct Options {
+
+  bool help;
+
+  cutlass::gemm::GemmCoord problem_size;
+  int batch_count;
+  cutlass::complex<float> alpha;
+  cutlass::complex<float> beta;
+
+  bool reference_check;
+  int iterations;
+
+  Options():
+    help(false),
+    problem_size({1024, 1024, 1024}),
+    batch_count(1),
+    reference_check(true),
+    iterations(20),
+    alpha(1),
+    beta() { }
+
+  bool valid() {
+    return true;
+  }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+    }
+
+    cmd.get_cmd_line_argument("m", problem_size.m());
+    cmd.get_cmd_line_argument("n", problem_size.n());
+    cmd.get_cmd_line_argument("k", problem_size.k());
+    cmd.get_cmd_line_argument("batch", batch_count);
+
+    cmd.get_cmd_line_argument("alpha", alpha.real());
+    cmd.get_cmd_line_argument("alpha_i", alpha.imag());
+    cmd.get_cmd_line_argument("beta", beta.real());
+    cmd.get_cmd_line_argument("beta_i", beta.imag());
+
+    cmd.get_cmd_line_argument("iterations", iterations);
+  }
+
+  /// Prints the usage statement.
+ std::ostream & print_usage(std::ostream &out) const { + + out << "10_planar_complex example\n\n" + << " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m GEMM M dimension\n" + << " --n GEMM N dimension\n" + << " --k GEMM K dimension\n" + << " --batch Number of GEMM operations executed in one batch\n" + << " --alpha Epilogue scalar alpha (real part)\n" + << " --alpha_i Epilogue scalar alpha (imaginary part)\n" + << " --beta Epilogue scalar beta (real part)\n\n" + << " --beta_i Epilogue scalar beta (imaginary part)\n\n" + << " --iterations Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n" + << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = problem_size.product() * batch_count * 4; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance test environment for planar complex +class TestbedPlanarComplex { +public: + + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::half_t; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementCompute = float; + using ElementAccumulator = float; + + // + // Data members + // + + cutlass::library::Handle handle; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::DeviceAllocation tensor_A; + cutlass::DeviceAllocation tensor_B; + cutlass::DeviceAllocation tensor_C; + 
cutlass::DeviceAllocation tensor_D; + cutlass::DeviceAllocation tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex( + Options const &options + ): + problem_size(options.problem_size), batch_count(options.batch_count) { + + // Allocate device memory for batched strided GEMM + tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); + tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); + tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + } + + void initialize() { + + uint64_t seed = 1073; + + // Use small integers to simplify correctness checking + int scope_max = 6; + int scope_min = -6; + + cutlass::reference::device::BlockFillRandomUniform( + tensor_A.get(), tensor_A.size(), seed, ElementA(scope_max), ElementA(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_B.get(), tensor_B.size(), seed * 2019, ElementB(scope_max), ElementB(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_C.get(), tensor_C.size(), seed * 2020, ElementC(scope_max), ElementC(scope_min), 0); + } + + Result profile(Options const &options) { + + Result result; + + initialize(); + + ElementA *ptr_A = tensor_A.get(); + ElementB *ptr_B = tensor_B.get(); + ElementC *ptr_C = tensor_C.get(); + ElementC *ptr_D = tensor_D.get(); + + int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; + int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; + int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; + int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; + + int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); + int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); + 
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + + int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); + int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); + int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); + int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMMs + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + + // + // Execute the planar complex GEMM kernel via the CUTLASS Library's + // dispatch routines. + // + // Note, for planar complex GEMM kernels, all numeric type arguments + // specify the data type of the base real types. These are understood to + // apply to planar complex representations of matrices in memory and to complex + // structures for scalars. + // + // See tools/library/include/cutlass/library/handle.h for more details. 
+ // + + result.status = handle.gemm_planar_complex( + problem_size.m(), // GEMM M dimension + problem_size.n(), // GEMM N dimension + problem_size.k(), // GEMM K dimension + + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars + + &options.alpha, // Pointer to alpha scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix + cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand + ptr_A, // Pointer to real part of A matrix + ptr_A + imag_stride_A, // Pointer to imaginary part of A matrix + lda, // Leading dimension of real part of A matrix + lda, // Leading dimension of imaginary part of A matrix + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix + cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand + ptr_B, // Pointer to real part of B matrix + ptr_B + imag_stride_B, // Pointer to imaginary part of B matrix + ldb, // Leading dimension of real part of B matrix + ldb, // Leading dimension of imaginary part of B matrix + + &options.beta, // Pointer to beta scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices + + ptr_C, // Pointer to real part of C matrix + ptr_C + imag_stride_C, // Pointer to imaginary part of C matrix + ldc, // Leading dimension of real part of C matrix + ldc, // Leading dimension of imaginary part of C matrix + + ptr_D, // Pointer to real part of D matrix + ptr_D + imag_stride_D, // Pointer to imaginary part of D matrix + ldd, // Leading dimension of real part of D matrix + ldd, // Leading dimension of imaginary part of D matrix + + batch_count, // 
Number of batched elements + + batch_stride_A, // Stride between batches of real parts of A matrix + batch_stride_A, // Stride between batches of imaginary parts of A matrix + + batch_stride_B, // Stride between batches of real parts of B matrix + batch_stride_B, // Stride between batches of imaginary parts of B matrix + + batch_stride_C, // Stride between batches of real parts of C matrix + batch_stride_C, // Stride between batches of imaginary parts of C matrix + + batch_stride_D, // Stride between batches of real parts of D matrix + batch_stride_D // Stride between batches of imaginary parts of D matrix + ); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; + return result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + if (handle.get_last_operation()) { + std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; + } + + // + // Compute reference in device code + // + + if (options.reference_check) { + + result.passed = true; + + for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { + cutlass::reference::device::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + options.alpha, + {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, + cutlass::ComplexTransform::kConjugate, + {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, + cutlass::ComplexTransform::kNone, + options.beta, + {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, + {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} + ); + + ElementC epsilon = 0.1_hf; + ElementC nonzero_floor = 0.1_hf; + + result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( + tensor_D.get() + idx * batch_stride_D, + tensor_D_ref.get() + idx * batch_stride_D, + batch_stride_D, + epsilon, + nonzero_floor + ); + } + + if (result.passed) { + std::cout << "Reference check passed." << std::endl; + } + else { + std::cerr << "Error - reference check failed." << std::endl; + } + } + + std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. 
+ // + // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. + // + // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 7) { + std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70." + << std::endl; + return -1; + } + else if (props.major == 7 && props.minor <= 2) { + // + // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; + return -1; + } + } + else if (props.major == 7 && props.minor >= 5) { + // + // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + return -1; + } + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + TestbedPlanarComplex testbed(options); + + Result result = testbed.profile(options); + + return result.passed ? 
0 : -1;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/examples/11_planar_complex_array/CMakeLists.txt b/examples/11_planar_complex_array/CMakeLists.txt
new file mode 100644
index 00000000..2a3f5987
--- /dev/null
+++ b/examples/11_planar_complex_array/CMakeLists.txt
@@ -0,0 +1,41 @@
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+#     * Redistributions of source code must retain the above copyright notice, this list of
+#       conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above copyright notice, this list of
+#       conditions and the following disclaimer in the documentation and/or other materials
+#       provided with the distribution.
+#     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+#       to endorse or promote products derived from this software without specific prior written
+#       permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ + +# Planar Complex Array GEMM example +cutlass_example_add_executable( + 11_planar_complex_array + planar_complex_array.cu +) + + +# +# This example depends on the CUTLASS Library +# + +target_link_libraries( + 11_planar_complex_array + PRIVATE + cutlass_lib + cutlass_tools_util_includes +) + diff --git a/examples/11_planar_complex_array/planar_complex_array.cu b/examples/11_planar_complex_array/planar_complex_array.cu new file mode 100644 index 00000000..3003a900 --- /dev/null +++ b/examples/11_planar_complex_array/planar_complex_array.cu @@ -0,0 +1,611 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Planar Complex Array Example
+
+  This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels which
+  execute a batch of matrix products, loading problem sizes and matrix base pointers from arrays
+  in global memory.
+
+  These kernels represent complex matrices by storing the real and imaginary parts of the matrix in
+  disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts
+  as either column-major or row-major layouts with a single leading dimension indicating the stride
+  between columns or rows.
+
+  The CUTLASS Library collects multiple template instantiations in a data structure and offers
+  a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures.
+
+  CUTLASS decouples matrix layout from complex transformation, so four possible transformations
+  are possible on the A and B operands:
+
+    n: column-major
+    c: column-major complex conjugate
+    t: row-major
+    h: row-major complex conjugate
+
+  To build strictly the planar complex kernels needed for general application, execute the following
+  CMake command in an empty build directory.
+
+    $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
+      -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
+
+  This builds all planar complex GEMM variants for Volta and Turing architectures.
+ + To build strictly the kernels needed for this example, an even narrower filter string may be + specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for + the 'CN' layout configuration (conjugate A operand with both A and B as column-major). + + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn + + $ make 11_planar_complex_array + + $ ./examples/11_planar_complex_array/11_planar_complex_array --m=2048 --n=1024 --k=512 --batch=10 +*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" + +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "cutlass/util/reference/device/gemm_planar_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/library/handle.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::complex alpha; + cutlass::complex beta; + + bool reference_check; + int iterations; + + Options(): + help(false), + problem_size({1024, 1024, 1024}), + 
batch_count(1), + reference_check(true), + iterations(20), + alpha(1), + beta() { } + + bool valid() { + return true; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch", batch_count); + + cmd.get_cmd_line_argument("alpha", alpha.real()); + cmd.get_cmd_line_argument("alpha_i", alpha.imag()); + cmd.get_cmd_line_argument("beta", beta.real()); + cmd.get_cmd_line_argument("beta_i", beta.imag()); + + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "11_planar_complex_array example\n\n" + << " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m GEMM M dimension\n" + << " --n GEMM N dimension\n" + << " --k GEMM K dimension\n" + << " --batch Number of GEMM operations executed in one batch\n" + << " --alpha Epilogue scalar alpha (real part)\n" + << " --alpha_i Epilogue scalar alpha (imaginary part)\n" + << " --beta Epilogue scalar beta (real part)\n\n" + << " --beta_i Epilogue scalar beta (imaginary part)\n\n" + << " --iterations Number of profiling iterations to perform.\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = problem_size.product() * batch_count * 4; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + 
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance test environment for planar complex +class TestbedPlanarComplex { +public: + + // Half-precision input and output + using Element = cutlass::half_t; + + // Configurations for layouts and internal computation + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementCompute = float; + using ElementAccumulator = float; + + // + // Data members + // + + cutlass::library::Handle handle; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::DeviceAllocation tensor_A; + cutlass::DeviceAllocation tensor_B; + cutlass::DeviceAllocation tensor_C; + cutlass::DeviceAllocation tensor_D; + cutlass::DeviceAllocation tensor_D_ref; + + cutlass::DeviceAllocation ptr_A_real; + cutlass::DeviceAllocation ptr_A_imag; + cutlass::DeviceAllocation ptr_B_real; + cutlass::DeviceAllocation ptr_B_imag; + cutlass::DeviceAllocation ptr_C_real; + cutlass::DeviceAllocation ptr_C_imag; + cutlass::DeviceAllocation ptr_D_real; + cutlass::DeviceAllocation ptr_D_imag; + + // + // Methods + // + + TestbedPlanarComplex( + Options const &options + ): + problem_size(options.problem_size), batch_count(options.batch_count) { + + // Allocate device memory for batched planar complex GEMM + tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); + tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); + tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); + + ptr_A_real.reset(batch_count); + ptr_A_imag.reset(batch_count); + ptr_B_real.reset(batch_count); + ptr_B_imag.reset(batch_count); + ptr_C_real.reset(batch_count); + 
ptr_C_imag.reset(batch_count); + ptr_D_real.reset(batch_count); + ptr_D_imag.reset(batch_count); + + } + + void initialize() { + + uint64_t seed = 1073; + + // Use small integers to simplify correctness checking + int scope_max = 6; + int scope_min = -6; + + cutlass::reference::device::BlockFillRandomUniform( + tensor_A.get(), tensor_A.size(), seed, Element(scope_max), Element(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_B.get(), tensor_B.size(), seed * 2019, Element(scope_max), Element(scope_min), 0); + + cutlass::reference::device::BlockFillRandomUniform( + tensor_C.get(), tensor_C.size(), seed * 2020, Element(scope_max), Element(scope_min), 0); + } + + Result profile(Options const &options) { + + Result result; + + initialize(); + + Element *ptr_A = tensor_A.get(); + Element *ptr_B = tensor_B.get(); + Element *ptr_C = tensor_C.get(); + Element *ptr_D = tensor_D.get(); + + int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; + int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; + int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; + int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; + + int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); + int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); + int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + + int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); + int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); + int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); + int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); + + // + // Configure pointers in global memory + // + + struct { + Element *base; + void **ptr_real; + void **ptr_imag; + int64_t batch_stride; + int64_t imag_stride; + 
} tensors[] = { + { tensor_A.get(), ptr_A_real.get(), ptr_A_imag.get(), batch_stride_A, imag_stride_A}, + { tensor_B.get(), ptr_B_real.get(), ptr_B_imag.get(), batch_stride_B, imag_stride_B}, + { tensor_C.get(), ptr_C_real.get(), ptr_C_imag.get(), batch_stride_C, imag_stride_C}, + { tensor_D.get(), ptr_D_real.get(), ptr_D_imag.get(), batch_stride_D, imag_stride_D} + }; + + for (auto const &tensor : tensors) { + for (int idx = 0; idx < batch_count; ++idx) { + + void *ptr_real = tensor.base + idx * tensor.batch_stride; + void *ptr_imag = tensor.base + idx * tensor.batch_stride + tensor.imag_stride; + + cudaError_t error = cudaMemcpy( + tensor.ptr_real + idx, + &ptr_real, + sizeof(void *), + cudaMemcpyHostToDevice); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to copy pointer to device memory"); + } + + error = cudaMemcpy( + tensor.ptr_imag + idx, + &ptr_imag, + sizeof(void *), + cudaMemcpyHostToDevice); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to copy pointer to device memory"); + } + } + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + + // + // Execute the planar complex array GEMM kernel via the CUTLASS Library's + // dispatch routines. + // + // Note, for planar complex array GEMM kernels, all numeric type arguments + // specify the data type of the base real types. 
These are understood to + // apply to planar complex representations of matrices in memory and to complex + // structures for scalars. + // + // See tools/library/include/cutlass/library/handle.h for more details. + // + + result.status = handle.gemm_planar_complex_array( + + problem_size.m(), // expected GEMM M dimension + problem_size.n(), // expected GEMM N dimension + problem_size.k(), // expected GEMM K dimension + batch_count, // Number of batched elements + + nullptr, + nullptr, + nullptr, + + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation + cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars + + &options.alpha, // Pointer to alpha scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix + cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand + + ptr_A_real.get(), // Pointer to array of pointers to real part of A matrix + ptr_A_imag.get(), // Pointer to array of pointers to imaginary part of A matrix + + lda, // Leading dimension of real part of A matrix + lda, // Leading dimension of imaginary part of A matrix + + cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix + cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix + cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand + + ptr_B_real.get(), // Pointer to array of pointers to real part of B matrix + ptr_B_imag.get(), // Pointer to array of pointers to imaginary part of B matrix + + ldb, // Leading dimension of real part of B matrix + ldb, // Leading dimension of imaginary part of B matrix + + &options.beta, // Pointer to beta scalar, of type complex + + cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices + + ptr_C_real.get(), // Pointer to 
array of pointers to real part of C matrix + ptr_C_imag.get(), // Pointer to array of pointers to imaginary part of C matrix + + ldc, // Leading dimension of real part of C matrix + ldc, // Leading dimension of imaginary part of C matrix + + ptr_D_real.get(), // Pointer to array of pointers to real part of D matrix + ptr_D_imag.get(), // Pointer to array of pointers to imaginary part of D matrix + + ldd, // Leading dimension of real part of D matrix + ldd // Leading dimension of imaginary part of D matrix + ); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; + return result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + if (handle.get_last_operation()) { + std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; + } + + // + // Compute reference in device code + // + + if (options.reference_check) { + + result.passed = true; + + for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { + cutlass::reference::device::GemmPlanarComplex< + Element, LayoutA, + Element, LayoutB, + Element, LayoutC, + ElementAccumulator + >( + problem_size, + options.alpha, + {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, + cutlass::ComplexTransform::kConjugate, + {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, + cutlass::ComplexTransform::kNone, + options.beta, + {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, + {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} + ); + + Element epsilon = 0.1_hf; + Element nonzero_floor = 0.1_hf; + + result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( + tensor_D.get() + idx * batch_stride_D, + tensor_D_ref.get() + idx * batch_stride_D, + batch_stride_D, + epsilon, + nonzero_floor + ); + } + + if (result.passed) { + std::cout << "Reference check passed." << std::endl; + } + else { + std::cerr << "Error - reference check failed." << std::endl; + } + } + + std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. 
+ // + // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. + // + // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 7) { + std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70." + << std::endl; + return -1; + } + else if (props.major == 7 && props.minor <= 2) { + // + // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; + return -1; + } + } + else if (props.major == 7 && props.minor >= 5) { + // + // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. + // + if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { + std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; + return -1; + } + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + TestbedPlanarComplex testbed(options); + + Result result = testbed.profile(options); + + return result.passed ? 
0 : -1; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e434cd7f..d5c503e9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -22,16 +22,14 @@ set(CUTLASS_EXAMPLES_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common) -function(cutlass_example_add_executable) +function(cutlass_example_add_executable NAME) set(options) set(oneValueArgs) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cutlass_add_executable(${__UNPARSED_ARGUMENTS}) - - list(GET __UNPARSED_ARGUMENTS 0 NAME) + cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS}) target_link_libraries( ${NAME} @@ -59,7 +57,10 @@ foreach(EXAMPLE 05_batched_gemm 06_splitK_gemm 07_volta_tensorop_gemm - 08_turing_tensorop_gemm) + 08_turing_tensorop_gemm + 10_planar_complex + 11_planar_complex_array +) add_subdirectory(${EXAMPLE}) add_dependencies(cutlass_examples ${EXAMPLE}) diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index c821ddaf..195f8abf 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -46,61 +46,68 @@ inline __device__ void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Specializations +// Determine the appropriate way to target PTX's "ldmatrix" instruction. // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if (__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ == 2) - #define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 1 -#else - #define CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED 0 -#endif - -#if ! defined(CUDA_NVVM_GET_SHARED_POINTER_ENABLED) - #define CUDA_NVVM_GET_SHARED_POINTER_ENABLED (CUDA_NVVM_GET_SHARED_POINTER_SUPPORTED) -#endif - #if ! 
defined(CUDA_LDMATRIX_SUPPORTED) #define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2)) #endif #if ! defined(CUDA_LDMATRIX_ENABLED) - #define CUDA_LDMATRIX_ENABLED (CUDA_LDMATRIX_SUPPORTED) + #define CUDA_LDMATRIX_ENABLED CUDA_LDMATRIX_SUPPORTED #endif -#if (CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +#if CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) #define CUDA_LDMATRIX_ACTIVATED 1 -#else - #define CUDA_LDMATRIX_ACTIVATED 0 #endif -#if defined(CUTLASS_GET_SMEM_POINTER) - // Use the existing implementation -#elif CUDA_NVVM_GET_SHARED_POINTER_ENABLED - #if ! defined(NVVM_GET_SMEM_POINTER) - #define NVVM_GET_SMEM_POINTER +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) + #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1)) +#endif + +#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED) + #define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED +#endif + +#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED extern "C" { - // - // This NVVM intrinsic is subject to change in future versions of CUDA. - // Clients should not call it directly. Rather, they should use the - // cutlass::arch::ldsm<>() template. - // - __device__ uint32_t __nvvm_get_smem_pointer(void*); + // + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. Rather, they should use the + // cutlass::arch::ldsm<>() template. 
+ // + __device__ uint32_t __nvvm_get_smem_pointer(void *); } - #endif - #define CUTLASS_GET_SMEM_POINTER(ptr) __nvvm_get_smem_pointer((void*)ptr) #endif ///////////////////////////////////////////////////////////////////////////////////////////////// +#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED + + /// CUTLASS helper to get SMEM pointer + inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { + return __nvvm_get_smem_pointer(const_cast(ptr)); + } + + /// CUTLASS helper to get SMEM pointer + inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { + return __nvvm_get_smem_pointer(ptr); + } + +#endif +///////////////////////////////////////////////////////////////////////////////////////////////// + template <> inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x; asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); @@ -120,9 +127,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y; asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); @@ -142,9 +149,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y, z, w; asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); @@ -167,9 +174,10 @@ template <> inline __device__ void ldsm( Array & D, void const* ptr) { + #if CUDA_LDMATRIX_ACTIVATED - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + 
unsigned addr = cutlass_get_smem_pointer(ptr); int x; asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); @@ -189,9 +197,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y; asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); @@ -211,9 +219,9 @@ inline __device__ void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUDA_LDMATRIX_ACTIVATED) - unsigned addr = CUTLASS_GET_SMEM_POINTER(ptr); + unsigned addr = cutlass_get_smem_pointer(ptr); int x, y, z, w; asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 6898f512..e59b710f 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -30,7 +30,9 @@ #include "cutlass/array.h" #include "cutlass/numeric_types.h" + #include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -49,6 +51,11 @@ struct OpMultiplyAddSaturate; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag indicating the complex multiply-add operation +struct OpMultiplyAddComplex; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Tag indicating the inner product is defined by (XOR, POPC) struct OpXorPopc; diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h index 90721f0d..57b50e00 100644 --- a/include/cutlass/arch/mma_sm70.h +++ b/include/cutlass/arch/mma_sm70.h @@ -27,7 +27,11 @@ */ #pragma once +#if defined(__CUDACC_RTC__) 
+#include +#else #include +#endif #include "mma.h" #include "cutlass/layout/matrix.h" @@ -84,6 +88,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -139,6 +144,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -194,6 +200,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -249,6 +256,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; CUTLASS_HOST_DEVICE void operator()( @@ -310,6 +318,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE @@ -385,6 +394,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE @@ -460,6 +470,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE @@ -535,6 +546,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; /// Multiply-add CUTLASS_HOST_DEVICE diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index ee9599b0..fb8a3dc5 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/arch/wmma.h" @@ -93,6 +97,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; CUTLASS_HOST_DEVICE void operator()( @@ -154,6 +159,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -215,6 +221,7 @@ struct Mma< 
using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -271,6 +278,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -327,6 +335,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -384,6 +393,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -446,6 +456,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -502,6 +513,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -558,6 +570,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -614,6 +627,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -676,6 +690,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -732,6 +747,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -788,6 +804,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -844,6 +861,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -906,6 +924,7 @@ 
struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -962,6 +981,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -1018,6 +1038,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -1074,6 +1095,7 @@ struct Mma< using FragmentC = Array; using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE @@ -1136,6 +1158,7 @@ struct Mma< using FragmentC = Array; using Operator = OpXorPopc; + using ArchTag = arch::Sm75; /// Computes multiply-add CUTLASS_HOST_DEVICE diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index b2f8d1eb..9843e134 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -68,24 +68,6 @@ namespace cutlass { namespace arch { -///////////////////////////////////////////////////////////////////////////////////////////////// -/// MemoryKind class (Shared vs. 
Global memory) -///////////////////////////////////////////////////////////////////////////////////////////////// -enum class MemoryKind { - kShared, // Data resides in shared memory - kGlobal // Data resides in global memory -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WarpParams holds architecture-specific constants -///////////////////////////////////////////////////////////////////////////////////////////////// -struct WarpParams { - static int const kThreadsPerWarp = 32; - static int const kQuadsPerWarp = 8; - static int const kThreadsPerQuad = 4; -}; - //////////////////////////////////////////////////////////////////////////////////////////////// /// Statically maps cutlass data types => nvcuda::wmma data types ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -196,7 +178,6 @@ template < struct Wmma; ///////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace arch } // namespace cutlass diff --git a/include/cutlass/arch/wmma_ptx.h b/include/cutlass/arch/wmma_ptx.h deleted file mode 100644 index 63614286..00000000 --- a/include/cutlass/arch/wmma_ptx.h +++ /dev/null @@ -1,105 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Templates exposing warp matrix multiply-add (WMMA) operations -*/ -#pragma once - -#include "cutlass/arch/wmma.h" - -namespace cutlass { -namespace arch { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// -/// WMMA structures to enclose * PTX * instruction string -/// -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WMMA PTX string load for A, B, and C matrices -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaLoadA; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaLoadB; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaLoadC; -///////////////////////////////////////////////////////////////////////////////////////////////// - - 
-///////////////////////////////////////////////////////////////////////////////////////////////// -/// WMMA Matrix multiply-add operation -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename ElementA_, ///< Data type of A elements - typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) - typename ElementB_, ///< Data type of B elements - typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) - typename ElementC_, ///< Element type of C matrix - typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) - typename Operator = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) -> -struct PtxWmma; -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// WMMA store for matrix D -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Shape_, ///< Size of the matrix product (concept: GemmShape) - typename Element_, ///< Data type of elements - typename Layout_, ///< Layout of matrix (concept: MatrixLayout) - MemoryKind Memory = MemoryKind::kShared ///< Data resides in shared or global memory -> -struct PtxWmmaStoreD; -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace arch -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index 63363ed7..6c989c9a 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -28,7 +28,11 
@@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -68,6 +72,7 @@ struct Wmma< using ElementC = ElementC_; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm70; // check supported wmma shape for the given multiplicand data types static_assert( diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index c5c15e9d..477a72c3 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -65,6 +69,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; // check supported wmma shape for the given multiplicand data types static_assert( @@ -145,6 +150,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; // check supported wmma shape for the given multiplicand data types static_assert( diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index a1bccbfc..2985be58 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -28,7 +28,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// @@ -65,6 +69,7 @@ struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm75; // check supported wmma shape for the given multiplicand data types static_assert( @@ -143,6 +148,7 @@ 
struct Wmma< using ElementC = int32_t; using LayoutC = LayoutC_; using Operator = cutlass::arch::OpXorPopc; + using ArchTag = arch::Sm75; // check supported wmma shape for the given multiplicand data types static_assert( diff --git a/include/cutlass/array_planar_complex.h b/include/cutlass/array_planar_complex.h new file mode 100644 index 00000000..e2dbbc47 --- /dev/null +++ b/include/cutlass/array_planar_complex.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Array holding disjoint real and imaginary parts of planar complex elements.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/array.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Array holding planar complex elements
+template
+struct ArrayPlanarComplex {
+
+  /// Underlying real element
+  using Element = Element_;
+
+  /// Number of logical elements
+  static size_t const kElements = N;
+
+  /// Underlying Fragment of real-valued elements
+  using ArrayReal = Array;
+
+public:
+
+  /// Fragment of real-valued elements representing the real part
+  ArrayReal real;
+
+  /// Fragment of real-valued elements representing the imaginary part
+  ArrayReal imag;
+
+public:
+
+  /// Ctor
+  CUTLASS_HOST_DEVICE
+  ArrayPlanarComplex() { }
+
+  /// Ctor
+  CUTLASS_HOST_DEVICE
+  ArrayPlanarComplex(
+    ArrayReal const &real_,
+    ArrayReal const &imag_
+  ):
+    real(real_), imag(imag_) { }
+
+  /// Sets the array to zero efficiently
+  CUTLASS_HOST_DEVICE
+  void clear() {
+    real.clear();
+    imag.clear();
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to deduce template arguments
+template
+CUTLASS_HOST_DEVICE +ArrayPlanarComplex +make_ArrayPlanarComplex(Array const &real, Array const &imag) { + return ArrayPlanarComplex(real, imag); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index b479d310..20c4a64a 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -25,7 +25,11 @@ #pragma once #include +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/half.h" @@ -351,6 +355,16 @@ CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) static_cast(imag(z)) * static_cast(imag(z)); } +/// Returns the complex conjugate +CUTLASS_HOST_DEVICE float conj(float const &z) { + return z; +} + +/// Returns the complex conjugate +CUTLASS_HOST_DEVICE double conj(double const &z) { + return z; +} + /// Returns the complex conjugate template CUTLASS_HOST_DEVICE complex conj(complex const &z) { @@ -414,6 +428,10 @@ CUTLASS_HOST_DEVICE complex sin(complex const &z) { template struct RealType< complex > { using Type = T; + + static complex from_real(double x) { + return complex(static_cast(x)); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -438,5 +456,18 @@ cutlass::complex from_real >(double r) { ////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct is_complex { + static bool const value = false; +}; + +template +struct is_complex> { + static bool const value = true; +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass +////////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/coord.h b/include/cutlass/coord.h index 7f40ede3..e2615755 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -28,6 +28,12 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + #include "cutlass/cutlass.h" namespace cutlass { diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index d50b2511..b5a0e5f4 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -35,10 +35,21 @@ namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#endif + /// Status code returned by CUTLASS operations enum class Status { kSuccess, ///< Operation was successful. kErrorMisalignedOperand, ///< operands fail alignment requirements. + kErrorInvalidDataType, ///< DataType fails requirement. kErrorInvalidLayout, ///< Layout fails alignment requirement. kErrorInvalidProblem, ///< Specified problem size is not supported by operator. kErrorNotSupported, ///< Operation is not supported on current device. 
@@ -48,12 +59,15 @@ enum class Status { }; /// Convert cutlass status to status strings -static inline char const* cutlassGetStatusString(cutlass::Status status) { +CUTLASS_HOST_DEVICE +static char const* cutlassGetStatusString(cutlass::Status status) { switch (status) { case cutlass::Status::kSuccess: return "Success"; case cutlass::Status::kErrorMisalignedOperand: return "Error Misaligned Operand"; + case cutlass::Status::kErrorInvalidDataType: + return "Error Invalid Data Type"; case cutlass::Status::kErrorInvalidLayout: return "Error Invalid Layout"; case cutlass::Status::kErrorInvalidProblem: @@ -79,16 +93,6 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#elif defined(__CUDACC_RTC__) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#else -#define CUTLASS_HOST_DEVICE inline -#endif - #define CUTLASS_ASSERT(x) assert(x) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -115,6 +119,12 @@ static inline char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Debug { + typename T::X x; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// static const int NUM_THREADS_PER_WARP = 32; static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2; @@ -131,6 +141,14 @@ int LaneId() { return ret; } +/// Computes SM number the thread is running on +CUTLASS_DEVICE +int SmId() { + int ret; + asm ("mov.u32 %0, %%smid;" : "=r"(ret)); + return ret; +} + #endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index 75843b38..9fe4b2b3 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -167,23 +167,23 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X /// Clamping constant value - ElementCompute const kClamp = ElementCompute(1 << (sizeof_bits::value - 1)); - - intermediate = max_accumulator(intermediate, -kClamp); - intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); + ElementCompute const kClamp = + ElementCompute((1U << (sizeof_bits::value - 1)) - 1); + + intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1)); + intermediate = min_accumulator(intermediate, kClamp); // Convert to destination numeric type NumericArrayConverter destination_converter; return destination_converter(intermediate); } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// // Conditional guards to enable partial specialization for packed integers -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) /// Applies a linear combination operator to an array of elements then clamps the output before /// converting to the output element type. 
@@ -320,7 +320,142 @@ public: #endif // Conditional guards to enable partial specialization for packed integers -///////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements then clamps +/// the output before converting to the output element type. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +/// Note: The below method only works for small k dimensions. The default +/// approach is above +/// TODO: Add logic to fallback to the default approach +template < + /// Data type used to load and store< tensors + typename ElementOutput_, + /// Number of elements computed per operation + int Count, + /// Rounding mode + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class FastLinearCombinationClamp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = int; + using ElementCompute = float; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + /// scales accumulators + ElementCompute alpha; + /// scales source tensor + ElementCompute beta; + /// pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *alpha_ptr; + /// pointer to source scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, ElementCompute beta) + : alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const *alpha_ptr, ElementCompute const 
*beta_ptr) + : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + FastLinearCombinationClamp(Params const ¶ms) { + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return beta_ != ElementCompute(0); } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, + FragmentOutput const &source, + ElementCompute uniform = ElementCompute(0)) const { + // Convert source to interal compute numeric type + FastNumericArrayConverter + source_converter; + FastNumericArrayConverter + accumulator_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Compute linear scaling in floating point + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + minimum min_accumulator; + maximum max_accumulator; + + // Float min-max + intermediate = + mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, + intermediate); // D = alpha * Accum + X + + /// Clamping constant value + ElementCompute const kClamp = + ElementCompute(1 << (sizeof_bits::value - 1)); + + intermediate = max_accumulator(intermediate, -kClamp); + intermediate = min_accumulator(intermediate, kClamp - 
ElementCompute(1)); + + // Convert to destination numeric type + FastNumericArrayConverter + destination_converter; + + return destination_converter(intermediate); + } +}; + +//////////////////////////////////////////////////////////////////////////////// } // namespace thread } // namespace epilogue diff --git a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h new file mode 100644 index 00000000..bfe6be78 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h @@ -0,0 +1,196 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations on planar-complex arrays +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/complex.h" +#include "cutlass/array_planar_complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to arrays of planar-complex elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +/// Note, as with most CUTLASS components for planar complex, the template arguments describe +/// the underlying real data type. 
+template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationPlanarComplex { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = ArrayPlanarComplex; + using FragmentAccumulator = ArrayPlanarComplex; + using ComputeFragment = ArrayPlanarComplex; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + complex alpha; ///< scales accumulators + complex beta; ///< scales source tensor + complex const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + complex const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + complex alpha, + complex beta + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + complex const *alpha_ptr, + complex const *beta_ptr + ): alpha(complex()), beta(complex()), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + complex alpha_; + complex beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationPlanarComplex(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_source( + source_converter(source.real), + source_converter(source.imag)); + + ComputeFragment converted_accumulator( + accumulator_converter(accumulator.real), + accumulator_converter(accumulator.imag)); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies > mul_op; + multiply_add > mul_add_op; + + // complex multiply: I = beta * C + intermediate.real = mul_op(beta_.real(), converted_source.real); + intermediate.imag = mul_op(beta_.real(), converted_source.imag); + + intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real); + intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag); + + // complex multiply-add: I = alpha * AB + I + intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real); + intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag); + + intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); + intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + 
return FragmentOutput( + destination_converter(intermediate.real), + destination_converter(intermediate.imag)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index f0514d4e..9afeb3eb 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -180,9 +180,13 @@ public: } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// +// Conditional guards to enable partial specialization for packed integers +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) + /// Applies a linear combination operator to an array of elements then clamps the output before /// converting to the output element type. 
/// @@ -327,6 +331,8 @@ public: } }; +#endif // Conditional guards to enable partial specialization for packed integers + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h index d11c623d..c3c40bab 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h @@ -58,16 +58,25 @@ namespace cutlass { namespace epilogue { namespace threadblock { -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization and defines sensible defaults for epilogues for complex*complex case +// 4 real-valued mma operations (Complex) +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// template < + /// Epilouge Shape typename Shape_, + /// Warp-level mma operator typename WarpMmaTensorOp_, + /// Number of k partitions int PartitionsK, + /// Epilogue output operator typename OutputOp_, - int ElementsPerAccess -> + /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load() + int ElementsPerAccess, + /// Multiply-add operator + typename Operator_ = arch::OpMultiplyAddComplex> struct DefaultEpilogueComplexTensorOp { using Shape = Shape_; @@ -75,6 +84,7 @@ struct DefaultEpilogueComplexTensorOp { static int const kPartitionsK = PartitionsK; using OutputOp = OutputOp_; static int const kElementsPerAccess = ElementsPerAccess; + using Operator = Operator_; using ElementOutput = typename 
OutputOp::ElementOutput; using LayoutC = typename WarpMmaTensorOp::LayoutC; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h b/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h new file mode 100644 index 00000000..081bcbac --- /dev/null +++ b/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Constructs a default epilogue for planar complex outputs. + + This template reuses components for real-valued epilogues and applies them to planar complex + output matrices. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/array_planar_complex.h" + +#include "cutlass/arch/arch.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. 
+template < + typename ThreadblockShape_, + typename WarpMma_, + typename OpcodeClass_, + typename ArchTag_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. +template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassTensorOp, + arch::Sm70, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueVoltaTensorOp< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. 
+template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassTensorOp, + arch::Sm75, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueTensorOp< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues. 
+template < + typename ThreadblockShape_, + typename WarpMmaOperator_, + typename ArchTag_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + arch::OpClassSimt, + ArchTag_, + PartitionsK, + OutputOp_, + ElementsPerAccess> { + + using RealEpilogue = DefaultEpilogueSimt< + ThreadblockShape_, + WarpMmaOperator_, + OutputOp_, + ElementsPerAccess + >; + + using Epilogue = EpiloguePlanarComplex< + ThreadblockShape_, + WarpMmaOperator_, + PartitionsK, + typename RealEpilogue::OutputTileIterator, + typename RealEpilogue::AccumulatorFragmentIterator, + typename RealEpilogue::WarpTileIterator, + typename RealEpilogue::SharedLoadIterator, + OutputOp_, + typename RealEpilogue::Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h index 1fd9f7a5..f0435e92 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h @@ -23,7 +23,7 @@ * **************************************************************************************************/ /*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + \brief Epilogue for threadblock scoped GEMMs using WMMA. The epilogue rearranges the result of a matrix product through shared memory to match canonical tensor layouts in global memory. Epilogues support conversion and reduction operations. 
diff --git a/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h b/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h index 74910789..6f4bd2ad 100644 --- a/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h @@ -146,6 +146,54 @@ struct DefaultInterleavedThreadMapTensorOp { //////////////////////////////////////////////////////////////////////////////// +/// Defines the optimal thread map for TensorOp accumulator layouts +template +struct DefaultInterleavedConvThreadMapTensorOp { + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + static int const kPartitionsK = PartitionsK; + using Element = Element_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kInterleavedK = InterleavedK; + + // + // Definitions + // + + struct Detail { + /// Tensor Operations fundamentally perform operations on 8 rows + static int const kTensorOpRows = 8; + static int const kWarpSize = 32; + + static_assert(!(ThreadblockShape::kM % WarpShape::kM) && + !(ThreadblockShape::kM % WarpShape::kM), + "Divisibility"); + + /// Number of warps + using WarpCount = + gemm::GemmShape; + + /// Number of participating threads + static int const kThreads = WarpCount::kCount * kWarpSize; + }; + + // + // ThreadMap + // + + /// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept + /// InterleavedOutputTileThreadMap + using Type = InterleavedConvOutputTileThreadMap< + MatrixShape, + MatrixShape, + Detail::kThreads, kElementsPerAccess, sizeof_bits::value>; +}; + +//////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index b8e1e0ca..fe6877aa 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ 
b/include/cutlass/epilogue/threadblock/epilogue.h @@ -32,7 +32,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -74,7 +78,7 @@ template < class Epilogue : public EpilogueBase< Shape_, - WarpMmaOperator_, + typename WarpMmaOperator_::Shape, PartitionsK, AccumulatorFragmentIterator_, WarpTileIterator_, @@ -84,7 +88,7 @@ public: using Base = EpilogueBase< Shape_, - WarpMmaOperator_, + typename WarpMmaOperator_::Shape, PartitionsK, AccumulatorFragmentIterator_, WarpTileIterator_, @@ -171,8 +175,9 @@ public: OutputOp const &output_op, ///< Output operator OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + int64_t imag_stride_dest = 0, ///< Arguments required for planar complex case - not used in real-valued case + int64_t imag_stride_src = 0) { ///< typename OutputTileIterator::Fragment source_fragment; diff --git a/include/cutlass/epilogue/threadblock/epilogue_base.h b/include/cutlass/epilogue/threadblock/epilogue_base.h index 55843e26..a8a0dc49 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -32,7 +32,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/matrix_shape.h" @@ -58,7 +62,7 @@ namespace threadblock { /// Base class for epilogues defining warp-level template < typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) - typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + typename WarpShape_, ///< Warp-level MMA operator (concept: 
gemm::warp::MmaTensorOp) int PartitionsK, ///< Number of partitions of the K dimension typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM @@ -68,7 +72,7 @@ class EpilogueBase { public: using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; + using WarpShape = WarpShape_; static int const kPartitionsK = PartitionsK; using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; using WarpTileIterator = WarpTileIterator_; @@ -83,11 +87,10 @@ public: /// Accumulator element using ElementAccumulator = typename AccumulatorTile::Element; - /// Number of warps using WarpCount = gemm::GemmShape< - Shape::kM / WarpMmaOperator::Shape::kM, - Shape::kN / WarpMmaOperator::Shape::kN, + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, kPartitionsK >; @@ -144,24 +147,6 @@ public: storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn})); } - - CUTLASS_DEVICE - void debug_print() { - if (threadIdx.x == 0) { - - #pragma unroll 1 - for (int r = 0; r < Shape::kRow; ++r) { - - #pragma unroll 1 - for (int c = 0; c < Shape::kColumn; ++c) { - - printf("%d ", int(storage.data()[r * StorageShape::kColumn + c])); - } - printf("\n"); - } - } - __syncthreads(); - } }; protected: diff --git a/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h b/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h new file mode 100644 index 00000000..8362748e --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h @@ -0,0 +1,397 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/array_planar_complex.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator for planar-complex output representations. +/// +/// Note, as with most CUTLASS components for planar complex, the template arguments describe +/// the underlying real data type. 
+template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator + typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) +> +class EpiloguePlanarComplex { +public: + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + /// Output layout is always row-major + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = ArrayPlanarComplex< + typename WarpMmaOperator::FragmentC::Element, + WarpMmaOperator::FragmentC::kElements + >; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference 
to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Shape of each warp-level operation + using WarpShape = typename WarpMmaOperator::Shape; + + /// Number of warps + using WarpCount = gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + kPartitionsK + >; + + /// Shared memory allocation + struct SharedStorage { + + // + // Type definitions + // + + /// Element type of shared memory + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Layout of shared memory allocation + using Layout = typename WarpTileIterator::Layout; + + /// Logical shape of the shared memory tile written to by all warps. 
+ using Shape = MatrixShape< + WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK, + WarpCount::kN * WarpTileIterator::Shape::kColumn + >; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape< + Shape::kRow + Padding::kRow, + Shape::kColumn + Padding::kColumn + >; + + static int const kImaginaryStride = StorageShape::kCount; + + // + // Data members + // + + AlignedBuffer storage; + + // + // Methods + // + + /// Returns a pointer to the shared memory buffer + CUTLASS_DEVICE + Element *data() { + return storage.data(); + } + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference() { + return TensorRef( + storage.data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + +private: + + // + // Data members + // + + SharedStorage &shared_storage_; + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator_; + +public: + + /// Constructor + CUTLASS_DEVICE + EpiloguePlanarComplex( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_storage_(shared_storage), + shared_load_iterator_(shared_storage.reference(), thread_idx), + warp_tile_iterator_(shared_storage.reference(), lane_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: + // + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int 
warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator_real, ///< Tile iterator for destination + OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator_real, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + OutputTileIterator source_iterator_imag) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment_real; + typename OutputTileIterator::Fragment source_fragment_imag; + + if (!output_op.is_source_needed()) { + source_iterator_real.clear_mask(); + source_iterator_imag.clear_mask(); + } + + source_fragment_real.clear(); + source_fragment_imag.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real); + AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag); + + // + // Iterate over accumulator tile + // + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + source_iterator_real.load(source_fragment_real); + source_iterator_imag.load(source_fragment_imag); + + ++source_iterator_real; + ++source_iterator_imag; + + // + // Convert and store fragment + // + + __syncthreads(); + + typename AccumulatorFragmentIterator::Fragment accum_fragment_real; + typename AccumulatorFragmentIterator::Fragment accum_fragment_imag; + + accum_fragment_iterator_real.load(accum_fragment_real); + 
accum_fragment_iterator_imag.load(accum_fragment_imag); + + ++accum_fragment_iterator_real; + ++accum_fragment_iterator_imag; + + this->warp_tile_iterator_.store(accum_fragment_real); + this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment_real[0]); + shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time"); + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment_real; + typename OutputTileIterator::Fragment output_fragment_imag; + + apply_output_operator_( + output_fragment_real, + output_fragment_imag, + output_op, + aligned_accum_fragment_real[0], + aligned_accum_fragment_imag[0], + source_fragment_real, + source_fragment_imag); + + // + // Store the final result + // + + destination_iterator_real.store(output_fragment_real); + destination_iterator_imag.store(output_fragment_imag); + + ++destination_iterator_real; + ++destination_iterator_imag; + } + } + +private: + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment &output_fragment_real, + typename OutputTileIterator::Fragment &output_fragment_imag, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag, + typename OutputTileIterator::Fragment const 
&source_fragment_real, + typename OutputTileIterator::Fragment const &source_fragment_imag) { + + OutputAccessType *output_frag_real_ptr = + reinterpret_cast(&output_fragment_real); + + OutputAccessType *output_frag_imag_ptr = + reinterpret_cast(&output_fragment_imag); + + AccumulatorAccessType const *compute_frag_real_ptr = + reinterpret_cast(&aligned_accum_fragment_real); + + AccumulatorAccessType const *compute_frag_imag_ptr = + reinterpret_cast(&aligned_accum_fragment_imag); + + OutputAccessType const *source_frag_real_ptr = + reinterpret_cast(&source_fragment_real); + + OutputAccessType const *source_frag_imag_ptr = + reinterpret_cast(&source_fragment_imag); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operator + auto result_fragment = output_op( + make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]), + make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i]) + ); + + output_frag_real_ptr[i] = result_fragment.real; + output_frag_imag_ptr[i] = result_fragment.imag; + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h index ba97f9cf..0a730ef1 100644 --- a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h +++ b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -85,9 +85,6 @@ class InterleavedEpilogue { using OutputTileIterator = OutputTileIterator_; using OutputOp = OutputOp_; - /// Output layout is always row-major - using Layout = layout::ColumnMajorInterleaved; - /// The complete warp-level 
accumulator tile using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index e4929cfe..fd28ac75 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -437,11 +437,10 @@ struct OutputTileOptimalThreadMap { /// - minimal address arithmetic /// - minimal predicate calculations /// -template struct InterleavedOutputTileThreadMap { using WarpCount = WarpCount_; - using MmaCount = MmaCount_; static int const kWarpSize = 32; static int const kThreads = Threads; @@ -460,7 +459,7 @@ struct InterleavedOutputTileThreadMap { // Output // - using Iterations = MmaCount; + using Iterations = Iterations_; using Delta = layout::PitchLinearShape; @@ -491,6 +490,67 @@ struct InterleavedOutputTileThreadMap { //////////////////////////////////////////////////////////////////////////////// +/// Template metaprogram for partitioning a 4D interleaved layout across warps +/// to achieve several performance objectives: +/// +/// - coalesced memory accesses in units of 64 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template +struct InterleavedConvOutputTileThreadMap { + using WarpCount = WarpCount_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail {}; + + // + // Output + // + + using Iterations = Iterations_; + + using Delta = MatrixShape; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp 
location + MatrixCoord warp_footprint{ + Delta::kRow * Iterations::kRow, + Delta::kColumn * Iterations::kColumn, + }; + + MatrixCoord warp_offset{warp_idx % WarpCount::kRow, + warp_idx / WarpCount::kRow}; + + // Compute per-lane offset + MatrixCoord thread_offset_in_warp{lane_idx / 4, + (lane_idx % 4) * kElementsPerAccess}; + + MatrixCoord thread_offset_in_threadblock_tile = + warp_footprint * warp_offset + thread_offset_in_warp; + + return thread_offset_in_threadblock_tile; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 9b894170..486d16c7 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -36,9 +36,9 @@ #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" #include "cutlass/matrix_shape.h" #include "cutlass/tensor_ref.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/epilogue/threadblock/output_tile_thread_map.h" @@ -107,16 +107,16 @@ public: // Data members // - Index stride; ///< stride in bytes between rows + LongIndex stride; ///< stride in bytes between rows - Index increment_row; ///< increment quantity (in bytes) to advance when moving between rows - Index increment_group; ///< increment quantity (in bytes) to advance when moving to the next group - Index increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster + LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows + LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group + LongIndex increment_cluster; ///< 
increment quantity (in bytes) to advance when moving to the next cluster - Index advance_row; ///< amount to add to move to the next 'row' position - Index advance_group; ///< amount to add to move to the next 'group' position - Index advance_cluster; ///< amount to add to move to the next 'cluster' position - Index advance_tile; ///< amount to add to move to the next 'tile' + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_group; ///< amount to add to move to the next 'group' position + LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position + LongIndex advance_tile; ///< amount to add to move to the next 'tile' // // Methods @@ -125,7 +125,7 @@ public: CUTLASS_HOST_DEVICE Status initialize(Index stride_) { - stride = stride_; + stride = LongIndex(stride_); increment_row = stride * ThreadMap::Delta::kRow; @@ -261,8 +261,8 @@ public: // Initialize pointer byte_pointer_ = reinterpret_cast(pointer) + - thread_offset.row() * params_.stride + - thread_offset.column() * sizeof(AccessType) / kElementsPerAccess; + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; @@ -276,7 +276,7 @@ public: /// Loads a fragment from memory CUTLASS_DEVICE - void load(Fragment &frag) { + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { uint8_t *byte_pointer = byte_pointer_; AccessType *frag_ptr = reinterpret_cast(&frag); @@ -299,7 +299,7 @@ public: bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { @@ -328,9 +328,15 @@ public: } } + /// Loads a fragment from memory + CUTLASS_DEVICE + 
void load(Fragment &frag) { + load_with_byte_offset(frag, 0); + } + /// Stores a fragment to memory CUTLASS_DEVICE - void store(Fragment const &frag) { + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { uint8_t *byte_pointer = byte_pointer_; AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -352,7 +358,7 @@ public: bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { @@ -382,6 +388,12 @@ public: } } + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_byte_offset(frag, 0); + } + /// Advances to the next position to load or store CUTLASS_HOST_DEVICE PredicatedTileIterator &operator++() { @@ -440,6 +452,7 @@ public: }; //////////////////////////////////////////////////////////////////////////////// + /// Tile iterator used to load output tile from shared memory in epilogue. 
/// /// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator @@ -447,7 +460,7 @@ public: template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) typename Element_, ///< Element data type - int InterleavedK ///< Number of Interleaved K + int InterleavedN ///< Number of Interleaved N > class InterleavedPredicatedTileIterator { public: @@ -455,7 +468,7 @@ public: using Element = Element_; - using Layout = layout::ColumnMajorInterleaved; + using Layout = layout::ColumnMajorInterleaved; using TensorRef = TensorRef; using ConstTensorRef = typename TensorRef::ConstTensorRef; @@ -483,10 +496,10 @@ public: // Data members // - Index stride; ///< stride in bytes between columns + LongIndex stride; ///< stride in bytes between columns - Index advance_row; ///< amount to add to move to the next 'row' position - Index advance_column; ///< amount to add to move to the next 'column' position + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_column; ///< amount to add to move to the next 'column' position // // Methods @@ -494,14 +507,16 @@ public: CUTLASS_HOST_DEVICE Status initialize(Index stride_) { - stride = stride_; + + stride = LongIndex(stride_); advance_row = ThreadMap::Delta::kContiguous * sizeof_bits::value / 8; - advance_column = - stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess * - sizeof_bits::value * ThreadMap::kWarpSize / 8; + advance_column = LongIndex(stride_) - ThreadMap::Iterations::kContiguous * + kElementsPerAccess * + sizeof_bits::value * + ThreadMap::kWarpSize / 8; return Status::kSuccess; } @@ -602,10 +617,10 @@ public: ): params_(params) { TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + - TensorCoord(threadblock_offset.contiguous() * InterleavedK, - threadblock_offset.strided() / InterleavedK); + TensorCoord(threadblock_offset.contiguous() * InterleavedN, + threadblock_offset.strided() / InterleavedN); - 
extent_col_ = extent.strided() / InterleavedK; + extent_col_ = extent.strided() / InterleavedN; thread_start_col_ = thread_offset.strided(); // Initialize predicates @@ -613,13 +628,13 @@ public: for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { mask_.predicates[c] = ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) < - (extent.contiguous() * InterleavedK)); + (extent.contiguous() * InterleavedN)); } // Initialize pointer byte_pointer_ = reinterpret_cast(pointer) + - thread_offset.strided() * params_.stride + - thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess; + LongIndex(thread_offset.strided()) * LongIndex(params_.stride) + + LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess; // Initialize internal state counter iteration_contiguous_ = iteration_strided_ = 0; @@ -634,6 +649,7 @@ public: /// Loads a fragment from memory CUTLASS_DEVICE void load(Fragment &frag) { + uint8_t *byte_pointer = byte_pointer_; AccessType *frag_ptr = reinterpret_cast(&frag); AccessType *memory_pointer = reinterpret_cast(byte_pointer); diff --git a/include/cutlass/epilogue/warp/tile_iterator_simt.h b/include/cutlass/epilogue/warp/tile_iterator_simt.h index 2164a134..2bf92e01 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_simt.h +++ b/include/cutlass/epilogue/warp/tile_iterator_simt.h @@ -100,12 +100,28 @@ public: /// Padding quantity using Padding = MatrixShape< 0, - 4 * Policy::kElementsPerAccess>; + 4 * Policy::kElementsPerAccess +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES + + 1 +#endif + >; private: +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES /// Storage type for accessing memory - using AccessType = AlignedArray; + using AccessType = AlignedArray< + Element, + 1 + >; + +#else + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + Policy::kElementsPerAccess + >; +#endif // // Data members @@ -130,18 +146,21 @@ public: unsigned lane_id ): 
pointer_(reinterpret_cast(ref.data())), - layout_(ref.stride()[0] / Policy::kElementsPerAccess) { + layout_(ref.stride()[0] / AccessType::kElements) { auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); MatrixCoord lane_offset = lane_layout.inverse(lane_id); - pointer_ += layout_(lane_offset); + pointer_ += layout_({ + lane_offset.row(), + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); } /// Adds a pointer offset CUTLASS_HOST_DEVICE TileIteratorSimt & add_pointer_offset(Index pointer_offset) { - pointer_ += pointer_offset / Policy::kElementsPerAccess; + pointer_ += pointer_offset / AccessType::kElements; return *this; } @@ -151,7 +170,7 @@ public: pointer_ += layout_({ tile_offset.row() * Shape::kRow, - (tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess) + (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) }); return *this; @@ -173,7 +192,7 @@ public: // de-vectorized stores using ScalarAccessType = AlignedArray; ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); - ScalarAccessType *scalarPointer = reinterpret_cast(pointer_); + ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { @@ -187,7 +206,7 @@ public: AccessType const *frag_ptr = reinterpret_cast(&frag); CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { - pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn] = frag_ptr[n]; + pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; } #endif } @@ -206,7 +225,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { - frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn]; + frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; } } diff --git 
a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h index 3984680f..a9ca2315 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -187,7 +187,8 @@ public: int access = access_idx % 2; int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess + - access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + access; + access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + + access + pointer_offset / Policy::kElementsPerAccess; int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; @@ -219,7 +220,9 @@ public: int access_quad = access_idx / 2; int access = access_idx % 2; - int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + access; + int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + + access + pointer_offset / Policy::kElementsPerAccess; + int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; frag_ptr[frag_idx] = pointer_[ptr_offset]; @@ -382,7 +385,7 @@ public: int ptr_row_offset = row_idx * 2; - int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}); + int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess; pointer_[ptr_offset] = frag_ptr[frag_idx]; } diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index a341922b..ebc821ed 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -25,7 +25,12 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif + #include "cutlass/cutlass.h" /** @@ -201,6 +206,18 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, rem = src - (quo * div); } +/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b 
+CUTLASS_HOST_DEVICE +int round_up(int a, int b) { + return ((a + b - 1) / b) * b; +} + +/// Returns the ceiling of (a / b) +CUTLASS_HOST_DEVICE +int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + /****************************************************************************** * Min/Max ******************************************************************************/ diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 00eeff2d..f712e04a 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -1,4 +1,4 @@ -/*************************************************************************************************** + /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted @@ -1207,4 +1207,8 @@ struct multiply_add, Array, Array> { ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index 55deea6e..c91aac20 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -192,7 +192,8 @@ template < OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, ElementAccumulator_>::EpilogueOutputOp, /// Threadblock-level swizzling operator - typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle, + typename ThreadblockSwizzle_ = + typename threadblock::GemmCohortThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages = DefaultGemmConfiguration 1) { - // Determine grid shape - ThreadblockSwizzle 
threadblock_swizzle; - - cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.split_k_slices); - - return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); } - return 0; + return bytes; } /// Initializes GEMM state from arguments. @@ -426,6 +431,7 @@ public: params_.ref_B.reset(args.ref_B.non_const_ref().data()); params_.ref_C.reset(args.ref_C.non_const_ref().data()); params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; params_.semaphore = static_cast(workspace); return Status::kSuccess; @@ -560,6 +566,8 @@ class Gemm gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. 
+ + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for + typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = 
typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator +> +class GemmArray { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = 
AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + using Operator = Operator_; + + /// Define the kernel + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + false, + Operator, + false + >::GemmKernel; + + using GemmKernel = kernel::GemmArray; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + + ElementA const * const *ptr_A; + LayoutA layout_A; + + ElementB const * const *ptr_B; + LayoutB layout_B; + + ElementC const * const *ptr_C; + LayoutC layout_C; + + ElementC * const * ptr_D; + LayoutC layout_D; + + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + ElementA const * const *ptr_A_, + LayoutA layout_A_, + ElementB const * const *ptr_B_, + LayoutB layout_B_, + ElementC const * const *ptr_C_, + LayoutC layout_C_, + ElementC * const * ptr_D_, + LayoutC layout_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ptr_A(ptr_A_), + layout_A(layout_A_), + ptr_B(ptr_B_), + layout_B(layout_B_), + ptr_C(ptr_C_), + layout_C(layout_C_), + ptr_D(ptr_D_), + layout_D(layout_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + GemmArray() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (args.layout_A.stride(0) % kAlignmentA) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_B.stride(0) % kAlignmentB) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_C.stride(0) % kAlignmentC) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_D.stride(0) % kAlignmentC) { + return Status::kErrorMisalignedOperand; + } + + if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) || + (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) || + (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + args.batch_count, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ptr_A, + args.layout_A, + args.ptr_B, + args.layout_B, + args.ptr_C, + args.layout_C, + args.ptr_D, + args.layout_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + args.batch_count, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + + params_ 
= typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ptr_A, + args.layout_A, + args.ptr_B, + args.layout_B, + args.ptr_C, + args.layout_C, + args.ptr_D, + args.layout_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Parital specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + typename Operator_ +> +class GemmArray< + ElementA_, + LayoutA_, + ElementB_, + LayoutB_, + ElementC_, + layout::ColumnMajor, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + AlignmentA, + AlignmentB, + Operator_ +> { +public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = 
InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = false; + + // + using UnderlyingOperator = GemmArray< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + + ElementA const * const *ptr_A; + LayoutA layout_A; + + ElementB const * const *ptr_B; + LayoutB layout_B; + + ElementC const * const *ptr_C; + LayoutC layout_C; + + ElementC * const * ptr_D; + LayoutC layout_D; + + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + ElementA const * const *ptr_A_, + LayoutA layout_A_, + ElementB const * const *ptr_B_, + LayoutB layout_B_, + ElementC const * const *ptr_C_, + LayoutC layout_C_, + ElementC * const * ptr_D_, + LayoutC layout_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ptr_A(ptr_A_), + layout_A(layout_A_), + ptr_B(ptr_B_), + layout_B(layout_B_), + ptr_C(ptr_C_), + layout_C(layout_C_), + ptr_D(ptr_D_), + layout_D(layout_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + 
UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmArray() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + + GemmCoord problem_size{ + args.problem_size.n(), + args.problem_size.m(), + args.problem_size.k() + }; + + return UnderlyingArguments( + problem_size, + args.ptr_B, + args.layout_B.stride(), + args.ptr_A, + args.layout_A.stride(), + args.ptr_C, + args.layout_C.stride(), + args.ptr_D, + args.layout_D.stride(), + args.epilogue, + args.batch_count + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index d2090e96..79fd4292 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -370,8 +370,8 @@ public: cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( args.problem_size, - args.batch_count, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); // Initialize the Params structure params_ = typename GemmKernel::Params{ diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 5b0dea3b..2e6eb1b0 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -201,6 +201,8 @@ template < ComplexTransform TransformA = ComplexTransform::kNone, /// Complex elementwise transformation on B operand ComplexTransform TransformB = ComplexTransform::kNone, + /// Multiply-add operator + typename Operator_ = arch::OpMultiplyAddComplex, /// If true, kernel supports split-K with serial reduction bool SplitKSerial = false > @@ -228,7 +230,11 @@ class GemmComplex { static int const kStages = Stages; static ComplexTransform const kTransformA = TransformA; static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; static bool const kSplitKSerial = SplitKSerial; + static int const kAlignmentA = 1; + 
static int const kAlignmentB = 1; + static int const kAlignmentC = EpilogueOutputOp::kCount; /// Define the kernel using GemmKernel = typename kernel::DefaultGemmComplex< @@ -249,6 +255,7 @@ class GemmComplex { kStages, kTransformA, kTransformB, + Operator, kSplitKSerial >::GemmKernel; @@ -498,6 +505,8 @@ template < ComplexTransform TransformA, /// Complex elementwise transformation on B operand ComplexTransform TransformB, + /// Multiply-add operator + typename Operator_, /// If true, kernel supports split-K as a serial reduction bool SplitKSerial > @@ -519,6 +528,7 @@ class GemmComplex< Stages, TransformA, TransformB, + Operator_, SplitKSerial > { public: @@ -542,6 +552,7 @@ public: using EpilogueOutputOp = EpilogueOutputOp_; using ThreadblockSwizzle = ThreadblockSwizzle_; static int const kStages = Stages; + using Operator = Operator_; static bool const kSplitKSerial = SplitKSerial; using UnderlyingOperator = GemmComplex< @@ -562,8 +573,15 @@ public: Stages, TransformA, TransformB, + Operator, SplitKSerial >; + + static int const kAlignmentA = UnderlyingOperator::kAlignmentB; + static int const kAlignmentB = UnderlyingOperator::kAlignmentA; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB; + static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA; using UnderlyingArguments = typename UnderlyingOperator::Arguments; using GemmKernel = typename UnderlyingOperator::GemmKernel; diff --git a/include/cutlass/gemm/device/gemm_universal.h b/include/cutlass/gemm/device/gemm_universal.h new file mode 100644 index 00000000..4b57fa0d --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal.h @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename 
InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone +> +class GemmUniversal : + GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + 
using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB> +class GemmUniversal { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using 
ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversal< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversal() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h new file mode 100644 index 00000000..070914f2 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter { +public: + + using GemmKernel = GemmKernel_; + + static_assert(std::is_same::value, + "Universal adapter expects the kernel to be row-major and transposes its arguments."); + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + using OperatorClass = typename GemmKernel::OperatorClass; + using ArchTag = typename GemmKernel::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using ElementA = typename GemmKernel::ElementB; + using LayoutA = typename layout::LayoutTranspose::type; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformB; + static int const kAlignmentA = GemmKernel::kAlignmentA; + + // Type, layout, and complex 
transform deliberately exchanged with A + using ElementB = typename GemmKernel::ElementA; + using LayoutB = typename layout::LayoutTranspose::type; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformA; + static int const kAlignmentB = GemmKernel::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = cutlass::layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalAdapter() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h new file mode 100644 index 00000000..de0ee183 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -0,0 +1,339 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template +class GemmUniversalBase { +public: + + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = 
typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + +protected: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +protected: + + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + +public: + + /// Constructs the GEMM. + GemmUniversalBase() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = + sizeof(ElementC) * + size_t(args.batch_stride_D) * + size_t(grid_tiled_shape.k()); + } + else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + return threadblock_swizzle.get_grid_shape(grid_tiled_shape); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size <= (48 << 10)) { + + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + GemmKernel::kThreadCount, + smem_size); + + if (result == cudaSuccess) { + return max_active_blocks; + } + } + else { + + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + GemmKernel::kThreadCount, + 
0); + + if (result != cudaSuccess) { + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + return std::min(max_active_blocks, smem_capacity / smem_size); + } + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params( + args, + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + ); + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 3a18a2b6..011e03c9 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -394,6 +394,15 @@ struct BatchedGemmCoord : public Coord<4, int> { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class GemmUniversalMode { + kGemm, + kGemmSplitKParallel, + kBatched, + kArray +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace gemm diff --git a/include/cutlass/gemm/kernel/default_gemm_complex.h b/include/cutlass/gemm/kernel/default_gemm_complex.h new file mode 100644 index 00000000..a9ef4e31 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_complex.h @@ -0,0 +1,118 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief
+      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
+      the appropriate threadblock-scoped epilogue.
+
+      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
+      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
+      specializations here choose 'device::GemmTransposed' to implement this functionality.
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level 
swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Multiply-add operator + typename Operator, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial +> +struct DefaultGemmComplex; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h b/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h new file mode 100644 index 00000000..3664fece --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief
+      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
+      the appropriate threadblock-scoped epilogue.
+
+      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
+      accommodated by exchanging A and B operands and assuming transposed layouts. Partial
+      specializations here choose 'device::GemmTransposed' to implement this functionality.
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_planar_complex.h" +#include "cutlass/gemm/kernel/gemm_planar_complex_array.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" +#include "cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// 
Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Math operation performed by GEMM (e.g. arch::OpMultiplyAdd) + typename Operator, + /// Conditional enabling to switch between stages + typename Enable = void + > +struct DefaultGemmPlanarComplexUniversal; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for pipelined mainloop +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator + > +struct DefaultGemmPlanarComplexUniversal< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + 
ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + typename std::enable_if<(Stages <= 2)>::type +> { + + /// Define planar complex valued variants instead + using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexPipelined< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + TransformA, + TransformB, + Operator + >::ThreadblockMma; + + /// Planar complex epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< + ThreadblockShape, + typename Mma::Policy::Operator, + OperatorClass, + ArchTag, + ThreadblockShape::kK / WarpShape::kK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmPlanarComplex< + Mma, + Epilogue, + ThreadblockSwizzle + >; + + // Array variant + using GemmArrayKernel = kernel::GemmPlanarComplexArray< + Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h new file mode 100644 index 00000000..23db577c --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief
+      Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
+      the appropriate threadblock-scoped epilogue.
+
+      Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
+      accommodated by exchanging A and B operands and assuming transposed layouts.
Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop 
+ int Stages, + /// Operation performed by GEMM + typename Operator, + /// + typename Enable = void + > +struct DefaultGemmUniversal; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultGemmUniversal< + ElementA, + LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + typename std::enable_if< ! 
cutlass::is_complex<ElementAccumulator>::value>::type
+> {
+
+  using DefaultGemmKernel = typename kernel::DefaultGemm<
+    ElementA,
+    LayoutA,
+    kAlignmentA,
+    ElementB,
+    LayoutB,
+    kAlignmentB,
+    ElementC,
+    LayoutC,
+    ElementAccumulator,
+    OperatorClass,
+    ArchTag,
+    ThreadblockShape,
+    WarpShape,
+    InstructionShape,
+    EpilogueOutputOp,
+    ThreadblockSwizzle,
+    Stages,
+    true,
+    Operator,
+    false
+  >::GemmKernel;
+
+  /// Define the kernel in terms of the default kernel
+  using GemmKernel = kernel::GemmUniversal<
+    typename DefaultGemmKernel::Mma,
+    typename DefaultGemmKernel::Epilogue,
+    ThreadblockSwizzle
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+//
+// Complex-valued GEMM kernels
+//
+
+template <
+  /// Element type for A matrix operand
+  typename ElementA,
+  /// Layout type for A matrix operand
+  typename LayoutA,
+  /// Complex elementwise transformation on A operand
+  ComplexTransform TransformA,
+  /// Access granularity of A matrix in units of elements
+  int kAlignmentA,
+  /// Element type for B matrix operand
+  typename ElementB,
+  /// Layout type for B matrix operand
+  typename LayoutB,
+  /// Complex elementwise transformation on B operand
+  ComplexTransform TransformB,
+  /// Access granularity of B matrix in units of elements
+  int kAlignmentB,
+  /// Element type for C and D matrix operands
+  typename ElementC,
+  /// Layout type for C and D matrix operands
+  typename LayoutC,
+  /// Element type for internal accumulation
+  typename ElementAccumulator,
+  /// Operator class tag
+  typename OperatorClass,
+  /// Tag indicating architecture to tune for
+  typename ArchTag,
+  /// Threadblock-level tile size (concept: GemmShape)
+  typename ThreadblockShape,
+  /// Warp-level tile size (concept: GemmShape)
+  typename WarpShape,
+  /// Warp-level tile size (concept: GemmShape)
+  typename InstructionShape,
+  /// Epilogue output operator
+  typename EpilogueOutputOp,
+  /// Threadblock-level swizzling operator
+  typename
ThreadblockSwizzle,
+  /// Number of stages used in the pipelined mainloop
+  int Stages,
+  /// Operation performed by GEMM
+  typename Operator
+  >
+struct DefaultGemmUniversal<
+  ElementA,
+  LayoutA,
+  TransformA,
+  kAlignmentA,
+  ElementB,
+  LayoutB,
+  TransformB,
+  kAlignmentB,
+  ElementC,
+  LayoutC,
+  ElementAccumulator,
+  OperatorClass,
+  ArchTag,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  EpilogueOutputOp,
+  ThreadblockSwizzle,
+  Stages,
+  Operator,
+  typename std::enable_if<cutlass::is_complex<ElementAccumulator>::value>::type
+> {
+
+  using DefaultGemmKernel = typename kernel::DefaultGemmComplex<
+    ElementA,
+    LayoutA,
+    ElementB,
+    LayoutB,
+    ElementC,
+    LayoutC,
+    ElementAccumulator,
+    OperatorClass,
+    ArchTag,
+    ThreadblockShape,
+    WarpShape,
+    InstructionShape,
+    EpilogueOutputOp,
+    ThreadblockSwizzle,
+    Stages,
+    TransformA,
+    TransformB,
+    Operator,
+    false
+  >::GemmKernel;
+
+  /// Define the kernel in terms of the default kernel
+  using GemmKernel = kernel::GemmUniversal<
+    typename DefaultGemmKernel::Mma,
+    typename DefaultGemmKernel::Epilogue,
+    ThreadblockSwizzle
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/gemm/kernel/default_gemv.h b/include/cutlass/gemm/kernel/default_gemv.h
old mode 100644
new mode 100755
diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h
index 2220465e..36cf6731 100644
--- a/include/cutlass/gemm/kernel/gemm.h
+++ b/include/cutlass/gemm/kernel/gemm.h
@@ -83,7 +83,7 @@ struct Gemm {
     //
 
     CUTLASS_HOST_DEVICE
-    Params() { }
+    Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
 
     CUTLASS_HOST_DEVICE
     Params(
@@ -94,7 +94,7 @@ struct Gemm {
       typename Epilogue::OutputTileIterator::TensorRef ref_C,
       typename Epilogue::OutputTileIterator::TensorRef ref_D,
       typename
OutputOp::Params output_op = typename OutputOp::Params(), - int *semaphore = nullptr + int *workspace = nullptr ): problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), @@ -106,13 +106,14 @@ struct Gemm { ref_C(ref_C), params_D(ref_D.layout()), ref_D(ref_D), - output_op(output_op), - semaphore(semaphore) { + output_op(output_op) { int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; } }; @@ -220,7 +221,9 @@ struct Gemm { thread_idx, tb_offset_B); - int warp_idx = threadIdx.x / 32; + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h new file mode 100644 index 00000000..30ff1d30 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -0,0 +1,253 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_coord.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
+  typename Epilogue_,             ///! Epilogue
+  typename ThreadblockSwizzle_    ///!
Threadblock swizzling function +> +struct GemmArray { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::Element const * const * ptr_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::Element const * const * ptr_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Element const * const * ptr_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::Element * const * ptr_D; + int64_t stride_D; + typename OutputOp::Params epilogue; + int batch_count; + int gemm_k_iterations; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size_, + cutlass::gemm::GemmCoord const & grid_tiled_shape_, + typename Mma::IteratorA::Element const * const * ptr_A_, + typename Mma::IteratorA::Layout layout_A, + typename Mma::IteratorB::Element const * const * ptr_B_, + typename Mma::IteratorB::Layout layout_B, + typename Epilogue::OutputTileIterator::Element const * const * ptr_C_, + typename Epilogue::OutputTileIterator::Layout layout_C, + typename Epilogue::OutputTileIterator::Element * const * ptr_D_, + typename Epilogue::OutputTileIterator::Layout layout_D, + typename OutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + grid_tiled_shape(grid_tiled_shape_), + params_A(layout_A), + ptr_A(ptr_A_), + params_B(layout_B), + ptr_B(ptr_B_), + params_C(layout_C), + ptr_C(ptr_C_), + params_D(layout_D), + ptr_D(ptr_D_), + 
epilogue(epilogue_),
+      batch_count(batch_count_),
+      gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) {
+
+    }
+  };
+
+  /// Shared memory storage structure
+  union SharedStorage {
+    typename Mma::SharedStorage main_loop;
+    typename Epilogue::SharedStorage epilogue;
+  };
+
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  GemmArray() { }
+
+  /// Executes one GEMM
+  CUTLASS_DEVICE
+  void operator()(Params const &params, SharedStorage &shared_storage) {
+
+    // Compute threadblock location
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
+
+    // Early exit if CTA is out of range
+    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
+      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
+
+      return;
+    }
+
+
+    // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension
+    for (int batch_idx = threadblock_swizzle.get_batch_idx();
+      batch_idx < params.batch_count;
+      batch_idx += gridDim.z) {
+
+      // Compute initial location in logical coordinates
+      cutlass::MatrixCoord tb_offset_A{
+        threadblock_tile_offset.m() * Mma::Shape::kM,
+        0
+      };
+
+      cutlass::MatrixCoord tb_offset_B{
+        0,
+        threadblock_tile_offset.n() * Mma::Shape::kN
+      };
+
+      // Compute position within threadblock
+      int thread_idx = threadIdx.x;
+
+      // Construct iterators to A and B operands
+      typename Mma::IteratorA iterator_A(
+        params.params_A,
+        const_cast<typename Mma::IteratorA::Element *>(params.ptr_A[batch_idx]),
+        params.problem_size.mk(),
+        thread_idx,
+        tb_offset_A);
+
+      typename Mma::IteratorB iterator_B(
+        params.params_B,
+        const_cast<typename Mma::IteratorB::Element *>(params.ptr_B[batch_idx]),
+        params.problem_size.kn(),
+        thread_idx,
+        tb_offset_B);
+
+      //
+      // Main loop
+      //
+
+      // Broadcast the warp_id computed by lane 0 to ensure dependent code
+      // is compiled as warp-uniform.
+      int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+
+      int lane_idx = threadIdx.x % 32;
+
+      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+
+      typename Mma::FragmentC accumulators;
+
+      accumulators.clear();
+
+
+      // Compute threadblock-scoped matrix multiply-add
+      mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
+
+      //
+      // Epilogue
+      //
+
+      OutputOp output_op(params.epilogue);
+
+      //
+      // Masked tile iterators constructed from members
+      //
+
+      threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
+
+      //assume identity swizzle
+      MatrixCoord threadblock_offset(
+        threadblock_tile_offset.m() * Mma::Shape::kM,
+        threadblock_tile_offset.n() * Mma::Shape::kN
+      );
+
+      // Tile iterator writing to output tile
+      typename Epilogue::OutputTileIterator iterator_C(
+        params.params_C,
+        const_cast<typename Epilogue::OutputTileIterator::Element *>(params.ptr_C[batch_idx]),
+        params.problem_size.mn(),
+        thread_idx,
+        threadblock_offset
+      );
+
+      // Tile iterator writing to output tile
+      typename Epilogue::OutputTileIterator iterator_D(
+        params.params_D,
+        params.ptr_D[batch_idx],
+        params.problem_size.mn(),
+        thread_idx,
+        threadblock_offset
+      );
+
+      Epilogue epilogue(
+        shared_storage.epilogue,
+        thread_idx,
+        warp_idx,
+        lane_idx);
+
+      // run efficient epilogue
+      epilogue(output_op, iterator_D, accumulators, iterator_C);
+    }
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace gemm
+} // namespace cutlass
+
diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h
index 68a5587f..8bf4354a 100644
--- a/include/cutlass/gemm/kernel/gemm_batched.h
+++ b/include/cutlass/gemm/kernel/gemm_batched.h
@@ -193,8 +193,10 @@ struct GemmBatched {
     //
     // Main loop
     //
 
-    // Construct thread-scoped matrix multiply
-    int warp_idx = threadIdx.x / 32;
+    // Broadcast the warp_id computed by lane 0 to ensure dependent code
+    // is compiled as
warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h new file mode 100644 index 00000000..3d975bb2 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -0,0 +1,693 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmPlanarComplex { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using Operator = typename Mma::Operator; + using ArchTag = typename Mma::ArchTag; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value); + + // + // Additional types needed for reflection + // + + using ElementAccumulator = typename Mma::Policy::Operator::ElementC; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::Shape; + + static int const kStages = Mma::kStages; + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + // + // Arguments structure + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A_real; + void const * 
ptr_A_imag; + + void const * ptr_B_real; + void const * ptr_B_imag; + + void const * ptr_C_real; + void const * ptr_C_imag; + + void * ptr_D_real; + void * ptr_D_imag; + + int lda_real; + int lda_imag; + int ldb_real; + int ldb_imag; + int ldc_real; + int ldc_imag; + int ldd_real; + int ldd_imag; + + int64_t batch_stride_A; + int64_t batch_stride_A_imag; + int64_t batch_stride_B; + int64_t batch_stride_B_imag; + int64_t batch_stride_C; + int64_t batch_stride_C_imag; + int64_t batch_stride_D; + int64_t batch_stride_D_imag; + + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr) + { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A_real, + void const * ptr_A_imag, + void const * ptr_B_real, + void const * ptr_B_imag, + void const * ptr_C_real, + void const * ptr_C_imag, + void * ptr_D_real, + void * ptr_D_imag, + int lda_real, + int lda_imag, + int ldb_real, + int ldb_imag, + int ldc_real, + int ldc_imag, + int ldd_real, + int ldd_imag, + int64_t batch_stride_A = 0, + int64_t batch_stride_A_imag = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_B_imag = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_C_imag = 0, + int64_t batch_stride_D = 0, + int64_t batch_stride_D_imag = 0 + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A_real(ptr_A_real), + ptr_A_imag(ptr_A_imag), + ptr_B_real(ptr_B_real), + ptr_B_imag(ptr_B_imag), + ptr_C_real(ptr_C_real), + ptr_C_imag(ptr_C_imag), + ptr_D_real(ptr_D_real), + ptr_D_imag(ptr_D_imag), + lda_real(lda_real), + lda_imag(lda_imag), + ldb_real(ldb_real), + ldb_imag(ldb_imag), + ldc_real(ldc_real), + 
ldc_imag(ldc_imag), + ldd_real(ldd_real), + ldd_imag(ldd_imag), + batch_stride_A(batch_stride_A), + batch_stride_A_imag(batch_stride_A_imag), + batch_stride_B(batch_stride_B), + batch_stride_B_imag(batch_stride_B_imag), + batch_stride_C(batch_stride_C), + batch_stride_C_imag(batch_stride_C_imag), + batch_stride_D(batch_stride_D), + batch_stride_D_imag(batch_stride_D_imag) { + + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A_real, args.ptr_B_real); + std::swap(args.ptr_A_imag, args.ptr_B_imag); + std::swap(args.lda_real, args.ldb_real); + std::swap(args.lda_imag, args.ldb_imag); + std::swap(args.batch_stride_A, args.batch_stride_B); + std::swap(args.batch_stride_A_imag, args.batch_stride_B_imag); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + + typename Mma::IteratorA::Params params_A_real; + typename Mma::IteratorA::Params params_A_imag; + typename Mma::IteratorB::Params params_B_real; + typename Mma::IteratorB::Params params_B_imag; + typename Epilogue::OutputTileIterator::Params params_C_real; + typename Epilogue::OutputTileIterator::Params params_C_imag; + typename Epilogue::OutputTileIterator::Params params_D_real; + typename Epilogue::OutputTileIterator::Params params_D_imag; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A_real; + void * ptr_A_imag; + void * ptr_B_real; + void * ptr_B_imag; + void * ptr_C_real; + void * ptr_C_imag; + void * ptr_D_real; + void * ptr_D_imag; + + int64_t batch_stride_A; + int64_t batch_stride_A_imag; + int64_t batch_stride_B; + int64_t batch_stride_B_imag; + int64_t batch_stride_C; + int64_t 
batch_stride_C_imag; + int64_t batch_stride_D; + int64_t batch_stride_D_imag; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr), + batch_stride_A(0), + batch_stride_A_imag(0), + batch_stride_B(0), + batch_stride_B_imag(0), + batch_stride_C(0), + batch_stride_C_imag(0), + batch_stride_D(0), + batch_stride_D_imag(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + params_A_real(args.lda_real), + params_A_imag(args.lda_imag), + params_B_real(args.ldb_real), + params_B_imag(args.ldb_imag), + params_C_real(args.ldc_real), + params_C_imag(args.ldc_imag), + params_D_real(args.ldd_real), + params_D_imag(args.ldd_imag), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A_real(const_cast(args.ptr_A_real)), + ptr_A_imag(const_cast(args.ptr_A_imag)), + ptr_B_real(const_cast(args.ptr_B_real)), + ptr_B_imag(const_cast(args.ptr_B_imag)), + ptr_C_real(const_cast(args.ptr_C_real)), + ptr_C_imag(const_cast(args.ptr_C_imag)), + ptr_D_real(args.ptr_D_real), + ptr_D_imag(args.ptr_D_imag), + batch_stride_A(args.batch_stride_A), + batch_stride_A_imag(args.batch_stride_A_imag), + batch_stride_B(args.batch_stride_B), + batch_stride_B_imag(args.batch_stride_B_imag), + batch_stride_C(args.batch_stride_C), + batch_stride_C_imag(args.batch_stride_C_imag), + batch_stride_D(args.batch_stride_D), + batch_stride_D_imag(args.batch_stride_D_imag), + semaphore(static_cast(workspace)) { + + } + + void update( + Arguments const &args, 
+ void *workspace = nullptr) { + + ptr_A_real = const_cast(args.ptr_A_real); + ptr_A_imag = const_cast(args.ptr_A_imag); + + ptr_B_real = const_cast(args.ptr_B_real); + ptr_B_imag = const_cast(args.ptr_B_imag); + + ptr_C_real = const_cast(args.ptr_C_real); + ptr_C_imag = const_cast(args.ptr_C_imag); + + ptr_D_real = const_cast(args.ptr_D_real); + ptr_D_imag = const_cast(args.ptr_D_imag); + + batch_stride_A = args.batch_stride_A; + batch_stride_A_imag = args.batch_stride_A_imag; + batch_stride_B = args.batch_stride_B; + batch_stride_B_imag = args.batch_stride_B_imag; + batch_stride_C = args.batch_stride_C; + batch_stride_C_imag = args.batch_stride_C_imag; + batch_stride_D = args.batch_stride_D; + batch_stride_D_imag = args.batch_stride_D_imag; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmPlanarComplex() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const &args) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) || + (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) || + (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = 
threadblock_swizzle.get_tile_offset(); + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A_real = static_cast(params.ptr_A_real); + ElementA *ptr_A_imag = static_cast(params.ptr_A_imag); + + ElementB *ptr_B_real = static_cast(params.ptr_B_real); + ElementB *ptr_B_imag = static_cast(params.ptr_B_imag); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A; + ptr_A_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A_imag; + ptr_B_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B; + ptr_B_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B_imag; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A_real = static_cast(params.ptr_A_real)[threadblock_tile_offset.k()]; + ptr_A_imag = static_cast(params.ptr_A_imag)[threadblock_tile_offset.k()]; + ptr_B_real = static_cast(params.ptr_B_real)[threadblock_tile_offset.k()]; + ptr_B_imag = static_cast(params.ptr_B_imag)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A_real( + params.params_A_real, + ptr_A_real, + {params.problem_size.m(), problem_size_k}, + 
thread_idx, + tb_offset_A); + + typename Mma::IteratorA iterator_A_imag( + params.params_A_imag, + ptr_A_imag, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B_real( + params.params_B_real, + ptr_B_real, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorB iterator_B_imag( + params.params_B_imag, + ptr_B_imag, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A_real, + iterator_A_imag, + iterator_B_real, + iterator_B_imag, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C_real = static_cast(params.ptr_C_real); + ElementC *ptr_C_imag = static_cast(params.ptr_C_imag); + ElementC *ptr_D_real = static_cast(params.ptr_D_real); + ElementC *ptr_D_imag = static_cast(params.ptr_D_imag); + + // + // Fetch pointers based on 
mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D_real += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D_imag += threadblock_tile_offset.k() * params.batch_stride_D_imag; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C; + ptr_C_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C_imag; + ptr_D_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D; + ptr_D_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D_imag; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C_real = static_cast(params.ptr_C_real)[threadblock_tile_offset.k()]; + ptr_C_imag = static_cast(params.ptr_C_imag)[threadblock_tile_offset.k()]; + ptr_D_real = static_cast(params.ptr_D_real)[threadblock_tile_offset.k()]; + ptr_D_imag = static_cast(params.ptr_D_imag)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C_real( + params.params_C_real, + ptr_C_real, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params.params_C_imag, + ptr_C_imag, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D_real( + params.params_D_real, + ptr_D_real, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params.params_D_imag, + ptr_D_imag, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // + // Construct epilogue + // + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C_real = iterator_D_real; + iterator_C_imag = iterator_D_imag; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h new file mode 100644 index 00000000..efb500b2 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -0,0 +1,583 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmPlanarComplexArray { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using Operator = typename Mma::Operator; + using ArchTag = typename Mma::ArchTag; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value); + + // + // Additional types needed for reflection + // + + using ElementAccumulator = typename Mma::Policy::Operator::ElementC; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::Shape; + + static int const kStages = Mma::kStages; + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + // + // Arguments structure + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + int const *ptr_M; + int const 
*ptr_N; + int const *ptr_K; + + void const * const * ptr_A_real; + void const * const * ptr_A_imag; + + void const * const * ptr_B_real; + void const * const * ptr_B_imag; + + void const * const * ptr_C_real; + void const * const * ptr_C_imag; + + void * const * ptr_D_real; + void * const * ptr_D_imag; + + int lda_real; + int lda_imag; + int ldb_real; + int ldb_imag; + int ldc_real; + int ldc_imag; + int ldd_real; + int ldd_imag; + + int64_t batch_stride_D; // unused + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kArray), + batch_count(1), + ptr_M(nullptr), + ptr_N(nullptr), + ptr_K(nullptr), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr), + batch_stride_D(0) + { } + + /// constructs an arguments structure + Arguments( + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + int const *ptr_M, + int const *ptr_N, + int const *ptr_K, + void const * const * ptr_A_real, + void const * const * ptr_A_imag, + void const * const * ptr_B_real, + void const * const * ptr_B_imag, + void const * const * ptr_C_real, + void const * const * ptr_C_imag, + void * const * ptr_D_real, + void * const * ptr_D_imag, + int lda_real, + int lda_imag, + int ldb_real, + int ldb_imag, + int ldc_real, + int ldc_imag, + int ldd_real, + int ldd_imag + ): + mode(GemmUniversalMode::kArray), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_M(ptr_M), + ptr_N(ptr_N), + ptr_K(ptr_K), + ptr_A_real(ptr_A_real), + ptr_A_imag(ptr_A_imag), + ptr_B_real(ptr_B_real), + ptr_B_imag(ptr_B_imag), + ptr_C_real(ptr_C_real), + ptr_C_imag(ptr_C_imag), + ptr_D_real(ptr_D_real), + ptr_D_imag(ptr_D_imag), + lda_real(lda_real), + lda_imag(lda_imag), + ldb_real(ldb_real), + ldb_imag(ldb_imag), + ldc_real(ldc_real), + ldc_imag(ldc_imag), + ldd_real(ldd_real), + ldd_imag(ldd_imag), + 
batch_stride_D(0) { + + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_M, args.ptr_N); + std::swap(args.ptr_A_real, args.ptr_B_real); + std::swap(args.ptr_A_imag, args.ptr_B_imag); + std::swap(args.lda_real, args.ldb_real); + std::swap(args.lda_imag, args.ldb_imag); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + + typename Mma::IteratorA::Params params_A_real; + typename Mma::IteratorA::Params params_A_imag; + typename Mma::IteratorB::Params params_B_real; + typename Mma::IteratorB::Params params_B_imag; + typename Epilogue::OutputTileIterator::Params params_C_real; + typename Epilogue::OutputTileIterator::Params params_C_imag; + typename Epilogue::OutputTileIterator::Params params_D_real; + typename Epilogue::OutputTileIterator::Params params_D_imag; + + typename EpilogueOutputOp::Params output_op; + + int batch_count; + + int const *ptr_M; + int const *ptr_N; + int const *ptr_K; + + void const * const * ptr_A_real; + void const * const * ptr_A_imag; + void const * const * ptr_B_real; + void const * const * ptr_B_imag; + void const * const * ptr_C_real; + void const * const * ptr_C_imag; + void * const * ptr_D_real; + void * const * ptr_D_imag; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + batch_count(0), + ptr_M(nullptr), + ptr_N(nullptr), + ptr_K(nullptr), + ptr_A_real(nullptr), + ptr_A_imag(nullptr), + ptr_B_real(nullptr), + ptr_B_imag(nullptr), + ptr_C_real(nullptr), + ptr_C_imag(nullptr), + ptr_D_real(nullptr), + ptr_D_imag(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size = 0, // ignored + void 
*workspace = nullptr // ignored + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + ptr_M(args.ptr_M), + ptr_N(args.ptr_N), + ptr_K(args.ptr_K), + params_A_real(args.lda_real), + params_A_imag(args.lda_imag), + params_B_real(args.ldb_real), + params_B_imag(args.ldb_imag), + params_C_real(args.ldc_real), + params_C_imag(args.ldc_imag), + params_D_real(args.ldd_real), + params_D_imag(args.ldd_imag), + output_op(args.epilogue), + batch_count(args.batch_count), + ptr_A_real(args.ptr_A_real), + ptr_A_imag(args.ptr_A_imag), + ptr_B_real(args.ptr_B_real), + ptr_B_imag(args.ptr_B_imag), + ptr_C_real(args.ptr_C_real), + ptr_C_imag(args.ptr_C_imag), + ptr_D_real(args.ptr_D_real), + ptr_D_imag(args.ptr_D_imag) { + + } + + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_M = args.ptr_M; + ptr_N = args.ptr_N; + ptr_K = args.ptr_K; + + ptr_A_real = args.ptr_A_real; + ptr_A_imag = args.ptr_A_imag; + + ptr_B_real = args.ptr_B_real; + ptr_B_imag = args.ptr_B_imag; + + ptr_C_real = args.ptr_C_real; + ptr_C_imag = args.ptr_C_imag; + + ptr_D_real = args.ptr_D_real; + ptr_D_imag = args.ptr_D_imag; + + output_op = args.epilogue; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmPlanarComplexArray() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const &args) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) || + (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) || + (args.problem_size.m() % kAlignmentC) || 
(args.problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + int batch_idx = threadblock_tile_offset.k(); + + int problem_size_m = params.problem_size.m(); + int problem_size_n = params.problem_size.n(); + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A_real = static_cast(const_cast(params.ptr_A_real[batch_idx])); + ElementA *ptr_A_imag = static_cast(const_cast(params.ptr_A_imag[batch_idx])); + + ElementB *ptr_B_real = static_cast(const_cast(params.ptr_B_real[batch_idx])); + ElementB *ptr_B_imag = static_cast(const_cast(params.ptr_B_imag[batch_idx])); + + // + // If pointers for problem sizes are specified, these are loaded from global memory + // + + if (params.ptr_M) { + problem_size_m = params.ptr_M[batch_idx]; + } + + if (params.ptr_N) { + problem_size_n = params.ptr_N[batch_idx]; + } + + if (params.ptr_K) { + problem_size_k = params.ptr_K[batch_idx]; + } + + int const kBlockCountM = (problem_size_m + Mma::Shape::kM - 1) / Mma::Shape::kM; + int const kBlockCountN = (problem_size_n + Mma::Shape::kN - 1) / Mma::Shape::kN; + + int const kGemmKIterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // + // Each threadblock loops over the logical problem size which the kernel may have discovered + // after the grid is launched. + // + + CUTLASS_PRAGMA_NO_UNROLL + for (int block_m = threadblock_tile_offset.m(); + block_m < kBlockCountM; + block_m += params.grid_tiled_shape.m()) { + + CUTLASS_PRAGMA_NO_UNROLL + for (int block_n = threadblock_tile_offset.n(); + block_n < kBlockCountN; + block_n += params.grid_tiled_shape.n()) { + + // + // Compute indices within threadblock and warp. 
+ // + int thread_idx = threadIdx.x; + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Proceed with regular GEMM logic. + // + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ block_m * Mma::Shape::kM, 0}; + cutlass::MatrixCoord tb_offset_B{ 0, block_n * Mma::Shape::kN }; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A_real( + params.params_A_real, + ptr_A_real, + {problem_size_m, problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorA iterator_A_imag( + params.params_A_imag, + ptr_A_imag, + {problem_size_m, problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B_real( + params.params_B_real, + ptr_B_real, + {problem_size_k, problem_size_n}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorB iterator_B_imag( + params.params_B_imag, + ptr_B_imag, + {problem_size_k, problem_size_n}, + thread_idx, + tb_offset_B); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma( + kGemmKIterations, + accumulators, + iterator_A_real, + iterator_A_imag, + iterator_B_real, + iterator_B_imag, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + //assume identity swizzle + MatrixCoord threadblock_offset( + block_m * Mma::Shape::kM, + block_n * Mma::Shape::kN + ); + + ElementC *ptr_C_real = static_cast(const_cast(params.ptr_C_real[batch_idx])); + ElementC *ptr_C_imag = static_cast(const_cast(params.ptr_C_imag[batch_idx])); + ElementC *ptr_D_real = 
static_cast(params.ptr_D_real[batch_idx]); + ElementC *ptr_D_imag = static_cast(params.ptr_D_imag[batch_idx]); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C_real( + params.params_C_real, + ptr_C_real, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params.params_C_imag, + ptr_C_imag, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D_real( + params.params_D_real, + ptr_D_real, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params.params_D_imag, + ptr_D_imag, + {problem_size_m, problem_size_n}, + thread_idx, + threadblock_offset + ); + + // + // Construct epilogue + // + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + + + } // for block_n + } // for block_m + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h new file mode 100644 index 00000000..11831d8d --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -0,0 +1,537 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmUniversal { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::Shape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename 
Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + int lda; + int ldb; + int ldc; + int ldd; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int lda, + int ldb, + int ldc, + int ldd + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to 
kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + semaphore(static_cast(workspace)) { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + + output_op 
= args.epilogue; + + semaphore = static_cast(workspace); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmUniversal() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based on mode. 
+ // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + + // Execute the epilogue operator to update the destination tensor. 
+ epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemv_batched_strided.h b/include/cutlass/gemm/kernel/gemv_batched_strided.h old mode 100644 new mode 100755 diff --git a/include/cutlass/gemm/threadblock/default_gemv_core.h b/include/cutlass/gemm/threadblock/default_gemv_core.h old mode 100644 new mode 100755 diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 25be0770..11af1de4 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -33,6 +33,7 @@ #include "cutlass/arch/arch.h" #include "cutlass/arch/wmma.h" +#include "cutlass/layout/matrix.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" @@ -143,8 +144,9 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass Simt) +/// Specialization for row-major output (OperatorClass TensorOp) template < /// Element type for A matrix operand typename ElementA, @@ -199,8 +201,8 @@ 
struct DefaultMma; }; -//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// /// Specialization for column-major-interleaved output template < /// Element type for A matrix operand @@ -268,7 +270,9 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// + #if defined(CUTLASS_ARCH_WMMA_ENABLED) /// Specialization for Wmma TensorOp operator with 2 staged pipeline template < @@ -384,6 +390,8 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// + /// Specialization for Wmma TensorOp operator with 1 staged pipeline template < ///< Element type for A matrix operand @@ -440,6 +448,7 @@ struct DefaultMma; }; + //////////////////////////////////////////////////////////////////////////////// #endif //CUTLASS_ARCH_WMMA_ENABLED diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h index 490b479e..51b5878f 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h @@ -598,6 +598,8 @@ struct DefaultMmaCore +struct DefaultMmaPlanarComplexPipelined { + + // Construct a planar complex variant from the real-valued variant + using RealMma = typename DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator_, + LayoutC_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + Stages, + Operator + >::ThreadblockMma; + + using ThreadblockMma = MmaPlanarComplexPipelined< + ThreadblockShape_, + typename RealMma::IteratorA, + typename RealMma::SmemIteratorA, + typename RealMma::IteratorB, + typename RealMma::SmemIteratorB, + ElementAccumulator_, + LayoutC_, + typename RealMma::Policy, + Stages, + TransformA, + TransformB + >; +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/gemv.h b/include/cutlass/gemm/threadblock/gemv.h old mode 100644 new mode 100755 diff --git a/include/cutlass/gemm/threadblock/mma_pipelined.h b/include/cutlass/gemm/threadblock/mma_pipelined.h index 1e707404..735950cf 100644 --- a/include/cutlass/gemm/threadblock/mma_pipelined.h +++ b/include/cutlass/gemm/threadblock/mma_pipelined.h @@ -118,6 +118,15 @@ public: /// Warp-level Mma using Operator = typename Policy::Operator; + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); @@ -256,8 +265,8 @@ public: __syncthreads(); - ++this->smem_iterator_B_; ++this->smem_iterator_A_; + ++this->smem_iterator_B_; // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory if (smem_write_stage_idx == 1) { @@ -299,7 +308,8 @@ public: } } - warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum); + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); } } @@ -311,3 +321,5 @@ public: } // namespace threadblock } // namespace gemm } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_base.h 
b/include/cutlass/gemm/threadblock/mma_planar_complex_base.h new file mode 100644 index 00000000..9491f56f --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_base.h @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaPlanarComplexBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Stride to the imaginary part of the A operand + static int const kImaginaryStrideA = ShapeA::kCount; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + /// Stride to the imaginary part of the A operand + static int const kImaginaryStrideB = ShapeB::kCount; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a 
warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPlanarComplexBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h new file mode 100644 index 00000000..ecf722d9 --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Transformation applied to A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Transformation applied to B + ComplexTransform TransformB = ComplexTransform::kNone +> +class MmaPlanarComplexPipelined : + public MmaPlanarComplexBase { +public: + ///< Base class + using Base = MmaPlanarComplexBase; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Policy describing tuning details + using Policy = Policy_; + + using ArchTag = typename Policy::Operator::ArchTag; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + /// Transformation applied to A + static ComplexTransform const kTransformA = TransformA; + + /// Transformation 
applied to B + static ComplexTransform const kTransformB = TransformB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = ArrayPlanarComplex< + typename Policy::Operator::FragmentC::Element, + Policy::Operator::FragmentC::kElements + >; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + private: + + using FragmentA = typename IteratorA::Fragment; + using FragmentB = typename IteratorB::Fragment; + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPlanarComplexPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp 
offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + +private: + + CUTLASS_DEVICE + void warp_mma_planar_complex( + Operator & warp_mma, + FragmentC &accum, + WarpFragmentA const & real_A, + WarpFragmentA const & imag_A, + WarpFragmentB const & real_B, + WarpFragmentB const & imag_B) { + + cutlass::negate> neg_op_B; + + WarpFragmentB neg_real_B = neg_op_B(real_B); + WarpFragmentB neg_imag_B = neg_op_B(imag_B); + + warp_mma(accum.real, real_A, real_B, accum.real); + + if (kTransformB == ComplexTransform::kNone) { + warp_mma(accum.imag, real_A, imag_B, accum.imag); + } + else { + warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); + } + + if (kTransformA == ComplexTransform::kNone) { + warp_mma(accum.imag, imag_A, real_B, accum.imag); + } + else { + warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); + } + + if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { + warp_mma(accum.real, imag_A, imag_B, accum.real); + } + else { + warp_mma(accum.real, imag_A, neg_imag_B, accum.real); + } + } + +public: + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A_real, + ///< iterator over A operand in global memory + IteratorA iterator_A_imag, + ///< iterator over B operand in global memory + IteratorB iterator_B_real, + ///< iterator over B operand in global memory + IteratorB iterator_B_imag, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A_real; + FragmentA tb_frag_A_imag; + + FragmentB 
tb_frag_B_real; + FragmentB tb_frag_B_imag; + + tb_frag_A_real.clear(); + tb_frag_A_imag.clear(); + + tb_frag_B_real.clear(); + tb_frag_B_imag.clear(); + + // The last kblock is loaded in the prolog + iterator_A_real.load(tb_frag_A_real); + iterator_A_imag.load(tb_frag_A_imag); + + iterator_B_real.load(tb_frag_B_real); + iterator_B_imag.load(tb_frag_B_imag); + + ++iterator_A_real; + ++iterator_A_imag; + + ++iterator_B_real; + ++iterator_B_imag; + + this->smem_iterator_A_.store(tb_frag_A_real); + this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); + + this->smem_iterator_B_.store(tb_frag_B_real); + this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_real_A[2]; + WarpFragmentA warp_frag_imag_A[2]; + + WarpFragmentB warp_frag_real_B[2]; + WarpFragmentB warp_frag_imag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); + this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); + + this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); + this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); + + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + if (gemm_k_iterations <= 1) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest 
latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(tb_frag_A_real); + this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); + + this->smem_iterator_B_.store(tb_frag_B_real); + this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); + + __syncthreads(); + + ++this->smem_iterator_B_; + ++this->smem_iterator_A_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); + + this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 
2]); + this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A_real.load(tb_frag_A_real); + iterator_A_imag.load(tb_frag_A_imag); + + iterator_B_real.load(tb_frag_B_real); + iterator_B_imag.load(tb_frag_B_imag); + + ++iterator_A_real; + ++iterator_A_imag; + ++iterator_B_real; + ++iterator_B_imag; + + // Avoid reading out of bounds if this was the last loop iteration + if (gemm_k_iterations <= 2) { + iterator_A_real.clear_mask(); + iterator_A_imag.clear_mask(); + iterator_B_real.clear_mask(); + iterator_B_imag.clear_mask(); + } + } + + warp_mma_planar_complex( + warp_mma, + accum, + warp_frag_real_A[warp_mma_k % 2], + warp_frag_imag_A[warp_mma_k % 2], + warp_frag_real_B[warp_mma_k % 2], + warp_frag_imag_B[warp_mma_k % 2]); + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_singlestage.h b/include/cutlass/gemm/threadblock/mma_singlestage.h index 99ec9d64..fd9890a4 100644 --- a/include/cutlass/gemm/threadblock/mma_singlestage.h +++ b/include/cutlass/gemm/threadblock/mma_singlestage.h @@ -168,7 +168,6 @@ public: // Perform accumulation in the 'd' output operand accum = src_accum; - FragmentA tb_frag_A; FragmentB tb_frag_B; @@ -183,8 +182,9 @@ public: ++iterator_B; // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; + WarpFragmentA warp_frag_A; + WarpFragmentB warp_frag_B; + Operator warp_mma; // Avoid reading out of bounds @@ -193,7 +193,6 @@ public: iterator_B.clear_mask(); } - // // Mainloop // @@ -203,7 
+202,6 @@ public: this->smem_iterator_A_.store(tb_frag_A); this->smem_iterator_B_.store(tb_frag_B); - __syncthreads(); // @@ -216,16 +214,16 @@ public: // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group // as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k) % 2]); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k) % 2]); + this->warp_tile_iterator_A_.load(warp_frag_A); + this->warp_tile_iterator_B_.load(warp_frag_B); ++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_B_; - warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum); + warp_mma(accum, warp_frag_A, warp_frag_B, accum); } // Add negative offsets to return smem load iterators to the 'start' of the shared memory diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/include/cutlass/gemm/threadblock/threadblock_swizzle.h index cd386b47..1beec2c2 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -30,7 +30,8 @@ #pragma once #include "cutlass/cutlass.h" - +#include "cutlass/layout/matrix.h" +#include "cutlass/platform/platform.h" #include "cutlass/gemm/gemm.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -141,6 +142,55 @@ struct GemmIdentityThreadblockSwizzle { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// A special version of GemmIdentityThreadblockSwizzle. See the choice of kTile below. 
+template +struct GemmCohortThreadblockSwizzle +{ + const int kTile = + (platform::is_same::value || + platform::is_same::value) + ? 4 + : 1; + + CUTLASS_HOST_DEVICE + GemmCohortThreadblockSwizzle() { } + + /// Returns the shape of the problem in units of logical tiles + CUTLASS_HOST_DEVICE + GemmCoord get_tiled_shape( + GemmCoord problem_size, + GemmCoord tile_size, + int split_k_slices) const { + + return GemmCoord( + (problem_size.m() + tile_size.m() - 1) / tile_size.m(), + (problem_size.n() + tile_size.n() - 1) / tile_size.n(), + split_k_slices); + } + + /// Computes CUDA grid dimensions given a size in units of logical tiles + CUTLASS_HOST_DEVICE + dim3 get_grid_shape(GemmCoord tiled_shape) const { + return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k()); + } + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset() const { + + int block_idx_x = RematerializeBlockIdxX(); + int block_idx_y = RematerializeBlockIdxY(); + + return GemmCoord{ + (block_idx_x / kTile), + (block_idx_y * kTile) + (block_idx_x % kTile), + RematerializeBlockIdxZ() + }; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Threadblock swizzling function for GEMMs struct GemmHorizontalThreadblockSwizzle { @@ -186,8 +236,8 @@ struct GemmBatchedIdentityThreadblockSwizzle { CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape( GemmCoord problem_size, - int batch_count, - GemmCoord tile_size) const { + GemmCoord tile_size, + int batch_count) const { return GemmCoord( (problem_size.m() + tile_size.m() - 1) / tile_size.m(), @@ -207,7 +257,7 @@ struct GemmBatchedIdentityThreadblockSwizzle { return GemmCoord{ RematerializeBlockIdxX(), RematerializeBlockIdxY(), - 0 + RematerializeBlockIdxZ() }; } diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op.h b/include/cutlass/gemm/warp/default_mma_tensor_op.h index 5bf1b74a..f64f46f9 100644 
--- a/include/cutlass/gemm/warp/default_mma_tensor_op.h +++ b/include/cutlass/gemm/warp/default_mma_tensor_op.h @@ -114,3 +114,6 @@ struct DefaultMmaTensorOp { } // namespace warp } // namespace gemm } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h deleted file mode 100644 index 073b131c..00000000 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ /dev/null @@ -1,351 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/complex.h" -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level 
MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA = ComplexTransform::kNone, - /// Complex transform on B operand - ComplexTransform TransformB = ComplexTransform::kNone, - /// Used for partial specialization - typename Enable = bool -> -class MmaComplexTensorOp; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename RealElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename RealElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename RealElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Complex transform on A operand - ComplexTransform TransformA, - /// Complex transform on B operand - ComplexTransform TransformB, - /// Used for partial specialization - typename Enable -> -class MmaComplexTensorOp< - Shape_, - complex, - LayoutA_, - complex, - LayoutB_, - complex, - LayoutC_, - Policy_, - TransformA, - TransformB, - Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = complex; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = complex; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = complex; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: 
MmaLanePolicySimt) - using Policy = Policy_; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = TransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = TransformB; - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - -public: - - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - 32, - 1 - >; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator< - MatrixShape, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kColumn, - 32, - 1 - >; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - - static_assert( - !(Shape::kM % Policy::Operator::Shape::kM) && - !(Shape::kN % Policy::Operator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - - /// Number of mma operations performed - using MmaIterations = MatrixShape< - Shape::kM / Policy::Operator::Shape::kM, - Shape::kN / Policy::Operator::Shape::kN - >; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator< - MatrixShape, - ElementC, - LayoutC, - typename Policy::Operator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this - /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued - /// parts are stored consecutively followed by all imaginary parts. This matches the structure - /// of Tensor Cores which are always real-valued matrix multiplies. 
- using FragmentC = typename IteratorC::Fragment; - - static_assert( - FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements, - "Unexpected planar complex fragment length."); - -private: - - // - // Data members - // - - /// Underlying real-valued matrix multiply operator (concept: arch::Mma) - typename Policy::Operator mma; - -public: - - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaComplexTensorOp() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()( - FragmentC &D, - FragmentA const &A, - FragmentB const &B, - FragmentC const &C) const { - - // Alias types for underlying real-valued matrix multiply operator - using MmaOperandA = typename Policy::Operator::FragmentA; - using MmaOperandB = typename Policy::Operator::FragmentB; - using MmaOperandC = typename Policy::Operator::FragmentC; - - static_assert(MmaOperandA::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the A operand." - "We can geneneralize later."); - - static_assert(MmaOperandB::kElements == 1, - "This implementation only supports math instructions in which exactly one element is needed for the A operand." - "We can geneneralize later."); - - D = C; - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - // mma(accum.real(), a.real(), b.real(), accum.real()); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. 
This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - operand_A[0] = A[m].real(); - operand_B[0] = B[n].real(); - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.imag(), a.real(), b.imag(), accum.imag()); - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - operand_A[0] = A[m].real(); - operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.real(), -a.imag(), b.imag(), accum.real()) - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - // A imaginary part is intentionally negated - operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag()); - operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); - - // Complex-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow); - - mma(*accum, operand_A, operand_B, *accum); - } - - // mma(accum.imag(), a.imag(), b.real(), accum.imag()) - CUTLASS_PRAGMA_UNROLL - for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { - - // Pack operands together. This may result in actual MOVs - MmaOperandA operand_A; - MmaOperandB operand_B; - - operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? 
-A[m].imag() : A[m].imag()); - operand_B[0] = B[n].real(); - - // Real-valued accumulator part - MmaOperandC *accum = reinterpret_cast(&D) + - (m + n * MmaIterations::kRow) + MmaIterations::kCount; - - mma(*accum, operand_A, operand_B, *accum); - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// TODO - partial specializations of real*complex and complex*real - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_planar_complex.h b/include/cutlass/gemm/warp/mma_planar_complex.h new file mode 100644 index 00000000..c5790440 --- /dev/null +++ b/include/cutlass/gemm/warp/mma_planar_complex.h @@ -0,0 +1,176 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/array_planar_complex.h" +#include "cutlass/gemm/warp/tile_iterator_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Underlying real-valued warp-level matrix multiply + typename Operator_, + /// Transformation applied to A operand (typically folded into math instruction) + ComplexTransform TransformA = ComplexTransform::kNone, + /// Transformation applied to B operand (typically folded into math instruction) + ComplexTransform TransformB = ComplexTransform::kNone +> +class MmaPlanarComplex { +public: + + /// Underlying real-valued 
warp-level matrix multiply + using Operator = Operator_; + + /// Shape of warp-level matrix multipy + using Shape = typename Operator::Shape; + + /// Transformation applied to A operand (typically folded into math instruction) + static ComplexTransform const kTransformA = TransformA; + + /// Transformation applied to B operand (typically folded into math instruction) + static ComplexTransform const kTransformB = TransformB; + + /// Fragment of elements + using FragmentA = ArrayPlanarComplex; + + /// Iterator into planar complex + using IteratorA = TileIteratorPlanarComplex; + + /// Layout in memory of the A operand + using LayoutA = typename Operator::LayoutA; + + using FragmentB = ArrayPlanarComplex; + + /// Iterator into planar complex + using IteratorB = TileIteratorPlanarComplex; + + /// Layout in memory of the B operand + using LayoutB = typename Operator::LayoutB; + + /// Tile iterator for accumulator + using IteratorC = TileIteratorPlanarComplex; + + /// Accumulator fragment + using FragmentC = ArrayPlanarComplex; + + /// Layout of accumulator fragment in memory + using LayoutC = typename Operator::LayoutC; + +private: + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + Operator::Shape::kM / Operator::Policy::Operator::Shape::kM, + Operator::Shape::kN / Operator::Policy::Operator::Shape::kN + >; + +public: + /// Ctor + CUTLASS_DEVICE + MmaPlanarComplex() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A_in, + FragmentB const &B_in, + FragmentC const &C) const { + + D.real = C.real; + D.imag = C.imag; + + // + // Transform fragments based on conjugate operations. 
+ // + + negate neg_A; + + FragmentA frag_A; + frag_A.real = A_in.real; + + if (kTransformA == ComplexTransform::kConjugate) { + frag_A.imag = neg_A(frag_A.imag); + } + else { + frag_A.imag = frag_A.imag; + } + + FragmentB frag_B; + frag_B.real = B_in.real; + + if (kTransformB == ComplexTransform::kConjugate) { + negate neg; + frag_B.imag = neg(frag_B.imag); + } + else { + frag_B.imag = frag_B.imag; + } + + // + // Accumulated real-valued matrix multiplies + // + + Operator real_mma; + + // D.i += A.i * B.r + real_mma(D.imag, frag_A.imag, frag_B.real, D.imag); + + // D.r += A.r * B.r + real_mma(D.real, frag_A.real, frag_B.real, D.real); + + // D.i += A.r * B.i + real_mma(D.imag, frag_A.real, frag_B.imag, D.imag); + + // D.r += -A.i * B.i + frag_A.imag = neg_A(frag_A.imag); + real_mma(D.real, frag_A.imag, frag_B.imag, D.real); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_simt.h b/include/cutlass/gemm/warp/mma_simt.h index eecb6aae..9166fe7c 100644 --- a/include/cutlass/gemm/warp/mma_simt.h +++ b/include/cutlass/gemm/warp/mma_simt.h @@ -100,6 +100,16 @@ public: /// Indicates class of matrix operator using OperatorClass = arch::OpClassSimt; + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Layout of threads using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value, layout::ColumnMajor, typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value, @@ -153,6 +163,9 
@@ public: /// Storage for A tile using FragmentA = typename IteratorA::Fragment; + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + /// Iterates over the B operand in memory using IteratorB = MmaSimtTileIterator< MatrixShape, @@ -167,6 +180,9 @@ public: /// Storage for B tile using FragmentB = typename IteratorB::Fragment; + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; + /// Iterates over the C operand in memory using IteratorC = MmaSimtTileIterator< MatrixShape, @@ -201,6 +217,15 @@ public: mma(d, a, b, c); } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + //TODO: Implement this + dst_A = A; + dst_B = B; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_tensor_op.h b/include/cutlass/gemm/warp/mma_tensor_op.h index d3e0fc0f..4e082db1 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_tensor_op.h @@ -31,7 +31,9 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" @@ -51,6 +53,60 @@ namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + return converter(source); + } +}; + +template +struct ConvertAndPack { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + return source; + } +}; + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array 
operator()(Array const &source) { + Converter converter; + + Array tmp; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); + tmp[i] = source[idx]; + } + + return converter(tmp); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> @@ -105,9 +161,18 @@ public: /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) using Policy = Policy_; + /// Architecture tag from underlying instruction + using ArchTag = typename Policy::Operator::ArchTag; + /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Number of threads participating in warp-level matrix product static int const kThreadCount = 32; @@ -128,6 +193,10 @@ public: /// Storage for A tile using FragmentA = typename IteratorA::Fragment; + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< MatrixShape, Operand::kB, ElementB, LayoutB, @@ -137,6 +206,10 @@ public: /// Storage for B tile using FragmentB = typename IteratorB::Fragment; + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + /// Iterates over the C operand in memory using IteratorC = MmaTensorOpAccumulatorTileIterator< MatrixShape, 
ElementC, LayoutC, @@ -179,8 +252,8 @@ public: CUTLASS_DEVICE void operator()( FragmentC &D, - FragmentA const &A, - FragmentB const &B, + TransformedFragmentA const &A, + TransformedFragmentB const &B, FragmentC const &C, int const &partitionN_idx = 0) const { @@ -221,6 +294,44 @@ public: } } } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + bool midway_depstage = + !(platform::is_same::value && + platform::is_same::value); + + // + // Define conversions from source type to instruction type + // + FloatRoundStyle const kRoundA = + PreferredRoundingMode::kRound; + FloatRoundStyle const kRoundB = + PreferredRoundingMode::kRound; + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_B = + reinterpret_cast const *>(&B); + Array * + ptr_dst_B = reinterpret_cast *>(&dst_B); + + dst_A = convert_A(A); + + ptr_dst_B[0] = convert_B(ptr_B[0]); + ptr_dst_B[1] = convert_B(ptr_B[1]); + + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -228,3 +339,5 @@ public: } // namespace warp } // namespace gemm } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_sm70.h index 836efb94..59515b5b 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_sm70.h @@ -103,6 +103,15 @@ public: /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; + /// Architecture tag + using ArchTag = arch::Sm70; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = 
ComplexTransform::kNone; + /// Number of threads participating in warp-level matrix product static int const kThreadCount = 32; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 811ff60e..45048d38 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -199,7 +199,8 @@ public: // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = + Array; private: @@ -516,7 +517,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -747,7 +748,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -1023,7 +1024,8 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = Array; private: @@ -1151,7 +1153,8 @@ class MmaTensorOpMultiplicandTileIterator< int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; byte_offset_ ^= k_groups_delta * sizeof_bits::value * - Layout::kElementsPerAccess / 8; + Layout::kElementsPerAccess * + Policy::LdsmShape::kContiguous / 8; pointer_ += tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + whole_tiles * stride_ / sections_; @@ -1406,7 +1409,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator @@ -1636,7 +1639,7 @@ class MmaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying 
tile iterator diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index be271e77..51c5ce26 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -165,7 +165,8 @@ public: // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = Array; private: @@ -473,7 +474,8 @@ public: // /// Fragment object holding a thread's part of a tile, needs on more time number of registers - using Fragment = Array; + using Fragment = Array; private: @@ -738,7 +740,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -962,7 +964,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: @@ -1557,7 +1559,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = + Array; private: @@ -1869,7 +1873,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator @@ -2097,7 +2101,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = typename Base::Fragment; private: /// Underlying tile iterator diff --git a/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/include/cutlass/gemm/warp/mma_tensor_op_wmma.h index bbfa2dcb..fe69867e 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_wmma.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_wmma.h @@ -106,6 +106,15 @@ public: /// 
Shape of the warp in units of thread (concept: MmaTensorOpPolicy) using Policy = Policy_; + /// Underlying architecture tag + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; @@ -193,7 +202,6 @@ public: } } } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/tile_iterator_planar_complex.h b/include/cutlass/gemm/warp/tile_iterator_planar_complex.h new file mode 100644 index 00000000..a3050c42 --- /dev/null +++ b/include/cutlass/gemm/warp/tile_iterator_planar_complex.h @@ -0,0 +1,244 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/array_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TileIteratorPlanarComplex { +public: + + /// Underlying iterator over real-valued tiles + using TileIterator = TileIterator_; + + /// Underlying element type + using Element = typename TileIterator::Element; + + /// Underlying layout type + using Layout = typename TileIterator::Layout; + + /// TensorRef type for loading element from a tensor + using TensorRef = typename TileIterator::TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename 
TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Planar complex fragment + using Fragment = ArrayPlanarComplex; + +public: + + /// Underlying tile iterator + TileIterator tile_iterator_; + + /// Offset (in units of bytes) to the imaginary part of the planar complex matrix + LongIndex imaginary_offset_; + +public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + TileIteratorPlanarComplex(): imaginary_offset_(0) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + TileIteratorPlanarComplex( + TensorRef const &ref, + int lane_id, + LongIndex imaginary_offset + ): + tile_iterator_(ref, lane_id), + imaginary_offset_((imaginary_offset * sizeof_bits::value) / 8) { } + + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) { + + tile_iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) { + + tile_iterator_.add_tile_offset(tile_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + TileIteratorPlanarComplex & operator++() { + ++tile_iterator_; + return *this; + } + + // + // WIP + // + + /// Advances the iterator along the opposite of the advance dimension + CUTLASS_HOST_DEVICE + TileIteratorPlanarComplex & operator--() { + --tile_iterator_; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) { + tile_iterator_.add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + 
TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) { + tile_iterator_.add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + tile_iterator_.load_with_byte_offset(frag.real, 0); + tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + + tile_iterator_.load_with_byte_offset(frag.real, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + Index byte_offset = (pointer_offset * sizeof_bits::value)/8; + + tile_iterator_.load_with_byte_offset(frag.real, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0); + tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. 
+ CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + Index byte_offset = (pointer_offset * sizeof_bits::value)/8; + + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + + tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); + tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. 
+ CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + tile_iterator_.set_kgroup_index(k_group); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/half.h b/include/cutlass/half.h index ba6f0d95..8ac08722 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -78,6 +78,10 @@ enum #include +#if defined(__i386__) || defined(__x86_64__) +#include +#endif + #define F16C_ROUND_NEAREST 0 #if !defined(__CUDA_ARCH__) @@ -110,9 +114,51 @@ __inline unsigned short _cvtss_sh (float __F, const int) { // Linux #include + +#if defined(__i386__) || defined(__x86_64__) +#include +#endif + #define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) -#endif +#endif // _MSC_VER + +class CpuId { + + bool f16c_enabled; + + CpuId() { + #if defined(__i386__) || defined(__x86_64__) + #if defined(_MSC_VER) + int exx[4]; + + __cpuid (exx, 1); + f16c_enabled = exx[2] & 0x20000000; + + #else + // GCC / Clang + int eax, ebx, ecx, edx; + + __cpuid (1 , eax, ebx, ecx, edx); + f16c_enabled = ecx & 0x20000000; + #endif + #else + // Arm / PowerPC etc. 
+ f16c_enabled = false; + #endif + } + +public: + + bool is_f16c_supported() const { + return f16c_enabled; + } + + static const CpuId& instance() { + static CpuId cpu; + return cpu; + } +}; #endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -154,10 +200,15 @@ struct alignas(2) half_t { static half_t convert(float const& flt) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) return half_t(__float2half_rn(flt)); - #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); - return bitcast(u); #else + + #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C + if( CpuId::instance().is_f16c_supported() ) { + unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); + return bitcast(u); + } + #endif + // software implementation rounds toward nearest even unsigned const& s = reinterpret_cast(flt); uint16_t sign = uint16_t((s >> 16) & 0x8000); @@ -248,10 +299,15 @@ struct alignas(2) half_t { static float convert(half_t const& x) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) return __half2float(x.to_half()); - #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C - unsigned short u = x.storage; - return _cvtsh_ss(u); #else + + #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C + if( CpuId::instance().is_f16c_supported() ) { + unsigned short u = x.storage; + return _cvtsh_ss(u); + } + #endif + uint16_t const &h = x.storage; int sign = ((h >> 15) & 1); int exp = ((h >> 10) & 0x1f); diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 223346c4..f6951769 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -29,7 +29,11 @@ */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/platform/platform.h" @@ -48,7 +52,7 @@ struct integer_subbyte { static bool const kSigned = Signed; /// External type - using T 
= typename std::conditional::type; + using T = typename platform::conditional::type; /// Storage type using Storage = uint8_t; diff --git a/include/cutlass/layout/layout.h b/include/cutlass/layout/layout.h index dda08daf..ba540e77 100644 --- a/include/cutlass/layout/layout.h +++ b/include/cutlass/layout/layout.h @@ -38,7 +38,6 @@ #include "cutlass/layout/matrix.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/tensor.h" -#include "cutlass/layout/tensor_nhwc.h" #include "cutlass/layout/vector.h" #include "cutlass/layout/tensor_op_multiplicand_sm70.h" diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index ba0361c0..2ab907a5 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -98,7 +98,7 @@ public: /// Assumes coordinate has convention (row, column) CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord const &coord) const { - return coord.row() * stride_[0] + coord.column(); + return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column(); } /// Inverse of layout function, mapping linear offset to logical coordinate @@ -134,7 +134,7 @@ public: /// Compute the number of contiguous elements needed to store a tensor with the given size CUTLASS_HOST_DEVICE LongIndex capacity(MatrixCoord const &extent) const { - return extent.row() * stride_[0]; + return LongIndex(extent.row()) * LongIndex(stride_[0]); } }; @@ -191,7 +191,7 @@ public: /// Assumes coordinate has convention (row, column) CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord const &coord) const { - return coord.row() + coord.column() * stride_[0]; + return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row(); } /// Inverse of layout function, mapping linear offset to logical coordinate @@ -227,7 +227,7 @@ public: /// Compute the number of contiguous elements needed to store a tensor with the given size CUTLASS_HOST_DEVICE LongIndex capacity(MatrixCoord const &extent) const { - return extent.column() * 
stride_[0]; + return LongIndex(extent.column()) * LongIndex(stride_[0]); } }; @@ -290,7 +290,7 @@ public: LongIndex operator()(MatrixCoord const &coord) const { Index row_major = coord.row() / kInterleave; Index row_minor = coord.row() % kInterleave; - return row_major * stride_[0] + coord.column() * kInterleave + row_minor; + return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor; } /// Inverse of layout function, mapping linear offset to logical coordinate @@ -397,7 +397,7 @@ public: LongIndex operator()(MatrixCoord const &coord) const { Index column_major = coord.column() / kInterleave; Index column_minor = coord.column() % kInterleave; - return column_major * stride_[0] + coord.row() * kInterleave + column_minor; + return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor; } /// Inverse of layout function, mapping linear offset to logical coordinate diff --git a/include/cutlass/layout/pitch_linear.h b/include/cutlass/layout/pitch_linear.h index 2a326c77..987c2bb8 100644 --- a/include/cutlass/layout/pitch_linear.h +++ b/include/cutlass/layout/pitch_linear.h @@ -116,6 +116,11 @@ public: return PitchLinearCoord(Base::operator-(b)); } + CUTLASS_HOST_DEVICE + PitchLinearCoord operator-() const { + return PitchLinearCoord(-at(0), -at(1)); + } + /// Element-wise multiplication CUTLASS_HOST_DEVICE PitchLinearCoord operator*(Base const& b) const { @@ -211,7 +216,7 @@ public: /// Assumes coordinate has convention (contiguous, strided) CUTLASS_HOST_DEVICE LongIndex operator()(TensorCoord const &coord) const { - return coord.contiguous() + coord.strided() * stride_[0]; + return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]); } /// Returns the logical coordinate given an offset. 
diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index dc2a7c88..2ef4e9d2 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -33,7 +33,11 @@ defined in cutlass/tensor_ref.h. */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include "assert.h" +#endif #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 228d6327..ef4604cb 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -45,7 +45,7 @@ enum class FloatRoundStyle { round_toward_zero, ///< round toward zero round_to_nearest, ///< round to nearest even round_toward_infinity, ///< round toward infinity - round_toward_neg_infinity, ///< round toward negative infinity + round_toward_neg_infinity ///< round toward negative infinity }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -268,8 +268,7 @@ struct NumericConverterClamp { result_type const kClamp_min = -kClamp_max - 1; bool is_int_min = !(s > kClamp_min); bool is_int_max = !(s < kClamp_max); - - return (is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s))); + return is_int_min ? kClamp_min : (is_int_max ? 
kClamp_max : convert_op(s)); } CUTLASS_HOST_DEVICE @@ -295,15 +294,15 @@ struct NumericConverterClamp { CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { - NumericConverter convert_op; + NumericConverter convert_op; - float kClamp_max = float((1 << (sizeof_bits::value - 1)) - 1); - float kClamp_min = -kClamp_max - 1; + double kClamp_max = double((1U << (sizeof_bits::value - 1)) - 1); + double kClamp_min = -kClamp_max - 1; - float source = s; + double source = s; - source = fmaxf(source, kClamp_min); - source = fminf(source, kClamp_max); + source = fmax(source, kClamp_min); + source = fmin(source, kClamp_max); return convert_op(source); } @@ -353,6 +352,24 @@ struct NumericArrayConverter { } }; +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return s; + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Array <= Array, round to nearest @@ -498,10 +515,15 @@ struct NumericArrayConverter { } }; + +///////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////// // Conditional guards to enable partial specialization for packed integers -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) /// Partial specialization for Array <= Array template < @@ -625,12 +647,13 @@ struct NumericArrayConverter { return convert(s); } }; - -#endif // Conditional guards to enable partial specialization 
for packed integers +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ + ((__CUDACC_VER_MAJOR__ > 10) || \ + ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) /// Partial specialization for Array <= Array template < @@ -707,4 +730,121 @@ struct NumericArrayConverter { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// FastNumericArrayConverter only works when the source is within center range. +/// Conversion operator for Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &s) { + result_type result; + NumericArrayConverter convert_; + + return convert_(s); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int tmp = source[i] + 1262485504 /*0x4B400000*/; + result[i] = reinterpret_cast(tmp) - 12582912.0f; + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { 
+ Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + float tmp = source[i] + 12582912.0f; + result[i] = reinterpret_cast(tmp); + } + + result[0] = __byte_perm(result[0], result[1], 0x40); + result[2] = __byte_perm(result[2], result[3], 0x40); + result[0] = __byte_perm(result[0], result[2], 0x5410); + + return reinterpret_cast(result[0]); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +/// Partial specialization for Array <= Array +template +struct FastNumericArrayConverter { + static_assert(!(N % 4), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + FastNumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = + reinterpret_cast *>(&result); + Array const *source_ptr = + reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { return convert(s); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines preferred rounding mode for a pair of types +template +struct PreferredRoundingMode { + static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index c44659f0..2282e43e 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -28,7 +28,11 @@ */ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else 
#include +#endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 3117cc7c..36d290bb 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -95,7 +95,11 @@ // Dependencies //----------------------------------------------------------------------------- +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #if !defined(__CUDACC_RTC__) //----------------------------------------------------------------------------- diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index 4f6b123c..ac4f0278 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -28,10 +28,13 @@ */ #pragma once -#if !defined(__CUDACC_RTC__) +#if defined(__CUDACC_RTC__) +#include +#include +#else #include -#endif #include +#endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/real.h b/include/cutlass/real.h index de1fee29..8fa4d710 100644 --- a/include/cutlass/real.h +++ b/include/cutlass/real.h @@ -30,6 +30,10 @@ namespace cutlass { template struct RealType { using Type = T; + + static T from_real(double x) { + return static_cast(x); + } }; template @@ -38,4 +42,5 @@ static T from_real(double r) { return T(r); } + } // namespace cutlass diff --git a/include/cutlass/reduction/thread/reduction_operators.h b/include/cutlass/reduction/thread/reduction_operators.h index af029124..3eed6209 100644 --- a/include/cutlass/reduction/thread/reduction_operators.h +++ b/include/cutlass/reduction/thread/reduction_operators.h @@ -85,7 +85,13 @@ struct ReduceAdd { plus op; - return op(accumulator, element); + NumericArrayConverter< + ElementAccumulator, + Element, + kCount, + PreferredRoundingMode::kRound> converter; + + return op(accumulator, converter(element)); } }; diff --git a/include/cutlass/semaphore.h b/include/cutlass/semaphore.h index c032c327..94b2eace 100644 --- a/include/cutlass/semaphore.h +++ 
b/include/cutlass/semaphore.h @@ -66,8 +66,9 @@ public: /// Permit fetching the synchronization mechanism early CUTLASS_DEVICE void fetch() { - - asm volatile ("ld.global.cg.s32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + if (wait_thread) { + asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + } } /// Gets the internal state @@ -80,14 +81,8 @@ public: CUTLASS_DEVICE void wait(int status = 0) { - if (wait_thread) { - while (state != status) { - - fetch(); - - __syncwarp(0x01); - - }; + while( __syncthreads_and(state != status) ) { + fetch(); } __syncthreads(); @@ -99,8 +94,7 @@ public: __syncthreads(); if (wait_thread) { - - asm volatile ("st.global.cg.s32 [%0], %1;\n" : : "l"(lock), "r"(status)); + asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); } } }; diff --git a/include/cutlass/tensor_ref.h b/include/cutlass/tensor_ref.h index a28dba57..6567fe81 100644 --- a/include/cutlass/tensor_ref.h +++ b/include/cutlass/tensor_ref.h @@ -379,6 +379,7 @@ template < typename Element, typename Layout > +CUTLASS_HOST_DEVICE bool TensorRef_aligned(TensorRef const &ref, int alignment) { int const kStrideRank = Layout::kStrideRank; diff --git a/include/cutlass/tensor_ref_planar_complex.h b/include/cutlass/tensor_ref_planar_complex.h new file mode 100644 index 00000000..54611911 --- /dev/null +++ b/include/cutlass/tensor_ref_planar_complex.h @@ -0,0 +1,368 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data.
+*/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct PlanarComplexReference { + + // + // Type definitions + // + + using Element = Element_; + using ComplexElement = complex; + + // + // Data members + // + + Element *real; + Element *imag; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + PlanarComplexReference( + Element *real_ = nullptr, + Element *imag_ = nullptr + ): + real(real_), imag(imag_) { } + + /// Loads the complex element + CUTLASS_HOST_DEVICE + operator complex() const { + return complex{*real, *imag}; + } + + /// Stores a complex element to the location pointed to by the reference + CUTLASS_HOST_DEVICE + PlanarComplexReference &operator=(complex const &rhs) { + *real = rhs.real(); + *imag = rhs.imag(); + return *this; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank + and layout within memory. 
A TensorRef combines a pointer and a Layout concept + +*/ +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class TensorRefPlanarComplex { + public: + /// Data type of individual access + using Element = Element_; + + /// Complex element type + using ComplexElement = complex; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + static_assert(sizeof_bits::value >= 8, + "Planar complex not suitable for subbyte elements at this time"); + + /// Reference type to an element + using Reference = PlanarComplexReference; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// TensorRef to constant data + using ConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorRef to non-constant data + using NonConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. 
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// Pointer + Element* ptr_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + /// Offset to imaginary part + LongIndex imaginary_stride_; + + public: + + // + // Methods + // + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + Element *ptr = nullptr, ///< pointer to start of tensor + Layout const &layout = Layout(), ///< layout object containing stride and mapping function + LongIndex imaginary_stride = 0 + ): + ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + NonConstTensorRef const &ref ///< TensorRef to non-const data + ): + ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } + + /// Returns a reference to constant-valued tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(ptr_, layout_, imaginary_stride_); + } + + CUTLASS_HOST_DEVICE + NonConstTensorRef non_const_ref() const { + return NonConstTensorRef( + const_cast::type *>(ptr_), + layout_, + imaginary_stride_); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { + ptr_ = ptr; + imaginary_stride_ = imaginary_stride; + } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { + ptr_ = ptr; + layout_ = layout; + imaginary_stride_ = imaginary_stride; + } + + /// Returns true if the TensorRef is non-null + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * data() const { return ptr_; } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + 
Element * imaginary_data() const { return ptr_ + imaginary_stride_; } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Reference data(LongIndex idx) const { + return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout & layout() { + return layout_; + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout layout() const { + return layout_; + } + + /// Gets the stride to an imaginary element + LongIndex imaginary_stride() const { + return imaginary_stride_; + } + + /// Gets the stride to an imaginary element + LongIndex &imaginary_stride() { + return imaginary_stride_; + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference operator[](TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { + ptr_ += offset_; + return *this; + } + + /// Adds an offset to each 
pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { + add_pointer_offset(offset(coord)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator+(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_coord_offset(b); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator+=(TensorCoord const& b) { + add_coord_offset(b); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator-(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_real() const { + return cutlass::TensorRef(data(), layout()); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_imag() const { + return cutlass::TensorRef(imaginary_data(), layout()); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. 
+template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +TensorRefPlanarComplex make_TensorRefPlanarComplex( + Element *ptr, + Layout const &layout, + int64_t imaginary_stride) { + + return TensorRefPlanarComplex(ptr, layout, imaginary_stride); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/tensor_view_planar_complex.h b/include/cutlass/tensor_view_planar_complex.h new file mode 100644 index 00000000..bdd29829 --- /dev/null +++ b/include/cutlass/tensor_view_planar_complex.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines a structure containing strides and a pointer to tensor data.
+
+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
+ data storage and is therefore lightweight and may be embedded in larger tensor objects or
+ memory structures.
+
+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
+ linear memory.
+*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref_planar_complex.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Data type of element stored within tensor + typename Element_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename Layout_ +> +class TensorViewPlanarComplex : public TensorRefPlanarComplex { + public: + + /// Base tensor reference + using Base = cutlass::TensorRefPlanarComplex; + + /// Mapping function from logical coordinate to internal n-D array + using Layout = Layout_; + + /// TensorRef pointing to constant memory + using ConstTensorRef = typename Base::ConstTensorRef; + + /// Underlying TensorRef type + using TensorRef = Base; + + /// Data type of individual access + using Element = Element_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Coordinate in storage n-D array + using Stride = typename Layout::Stride; + + /// TensorView pointing to constant memory + using ConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorView pointing to non-constant memory + using NonConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. 
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// View extent + TensorCoord extent_; + + public: + + // + // Methods + // + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + Element *ptr, ///< pointer to start of tensor + Layout const &layout, ///< layout object containing stride and mapping function + LongIndex imaginary_stride, ///< stride between real and imaginary part + TensorCoord const &extent ///< size of the view in logical coordinates + ): + Base(ptr, layout, imaginary_stride), extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + TensorRef const &ref, ///< pointer and layout object referencing a tensor + TensorCoord const &extent ///< logical size of tensor + ): + Base(ref), extent_(extent) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + NonConstTensorView const &view ///< TensorView to non-const data + ): + Base(view), extent_(view.extent_) { } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { + Base::reset(ptr, layout, imaginary_stride); + this->resize(extent_); + } + + /// Changes the size of the view without affecting pointer or layout + CUTLASS_HOST_DEVICE + void resize(TensorCoord extent) { + this->extent_ = extent; + } + + /// Returns the extent of the view (the size along each logical dimension). + CUTLASS_HOST_DEVICE + TensorCoord const& extent() const { return extent_; } + + /// Returns the extent along a particular logical dimension. 
+ CUTLASS_HOST_DEVICE + Index extent(int dim) const { return extent_.at(dim); } + + /// Determines whether a location is within a tensor + CUTLASS_HOST_DEVICE + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < kRank; ++dim) { + if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { + return false; + } + } + return true; + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + Base ref() const { + return Base(this->data(), this->layout(), this->imaginary_stride()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), this->layout()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent_); + } + + /// Returns a Tensor_view given location and size quantities + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex subview( + TensorCoord extent, ///< extent of the resulting view + TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view + ) const { + + return TensorViewPlanarComplex(ref(), extent.clamp(extent_ - location)).add_coord_offset(location); + } + + /// Returns the number of scalar elements needed to store tensor. 
+ CUTLASS_HOST_DEVICE + size_t capacity() const { + return Base::layout().capacity(extent_); + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator+( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorViewPlanarComplex result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator+=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator-( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator-=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(-this->offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_real() const { + return cutlass::TensorView(this->data(), this->layout(), extent_); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_imag() const { + return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. 
+template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( + Element *ptr, + Layout const &layout, + typename Layout::LongIndex imaginary_stride, + typename Layout::TensorCoord const &extent) { + + return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index ff754cfa..2ab40add 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -128,13 +128,13 @@ class PredicatedTileAccessIterator::value / 8; if (kAdvanceRank) { // advance along strided dimension inc_advance_ = - Shape::kStrided * stride_ * sizeof_bits::value / 8; + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; } else { // advance along contiguous dimension inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; } - inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) * - ThreadMap::Delta::kStrided * stride_ * + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; }; }; @@ -280,7 +280,7 @@ class PredicatedTileAccessIterator::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + AccessType *frag_ptr = reinterpret_cast(&frag); CUTLASS_PRAGMA_UNROLL @@ -310,10 +315,12 @@ class PredicatedTileIterator(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); if (address_iterator_.valid()) { - frag_ptr[idx] = *ptr; + frag_ptr[idx] = *access_ptr; } ++address_iterator_; } @@ -323,11 +330,17 @@ class 
PredicatedTileIterator::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { address_iterator_.set_iteration_index(0); AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -340,8 +353,11 @@ class PredicatedTileIterator(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + if (address_iterator_.valid()) { - *(address_iterator_.get() + pointer_offset) = frag_ptr[idx]; + *access_ptr = frag_ptr[idx]; } ++address_iterator_; } @@ -351,7 +367,7 @@ class PredicatedTileIterator::value / 8); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, Index byte_offset) { address_iterator_.set_iteration_index(0); AccessType *frag_ptr = reinterpret_cast(&frag); @@ -157,7 +163,11 @@ public: CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { int access_idx = c + s * ThreadMap::Iterations::kContiguous; - frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + frag_ptr[access_idx] = *access_ptr; ++address_iterator_; } } @@ -172,6 +182,11 @@ public: /// Store a fragment to memory CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { address_iterator_.set_iteration_index(0); AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -180,7 +195,11 @@ public: CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { int access_idx = c + s * ThreadMap::Iterations::kContiguous; - *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + + char *byte_ptr = 
reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; ++address_iterator_; } } @@ -189,7 +208,7 @@ public: /// Store a fragment to memory CUTLASS_DEVICE void store(Fragment const &frag) { - store_with_pointer_offset(frag, 0); + store_with_byte_offset(frag, 0); } }; @@ -567,6 +586,11 @@ class RegularTileIterator::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { address_iterator_.set_iteration_index(0); AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -575,7 +599,11 @@ class RegularTileIterator(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; ++address_iterator_; } } @@ -806,3 +834,5 @@ class RegularTileIterator(&frag); - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { @@ -1185,13 +1185,14 @@ class RegularTileIterator< void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { AccessType const *frag_ptr = reinterpret_cast(&frag); - Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; - access_ptr += 16 * (s / 2); + access_ptr += 16 * (s / 2) + vec_pointer_offset; CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { @@ -1199,8 +1200,7 @@ class RegularTileIterator< for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { int access_offset = - c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + - vec_pointer_offset + i * 
line_size; + c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * Detail::kIterarionsPerAccess + i; diff --git a/media/docs/functionality.md b/media/docs/functionality.md index 26171ca2..de8da82d 100644 --- a/media/docs/functionality.md +++ b/media/docs/functionality.md @@ -28,6 +28,7 @@ Hyperlinks to relevant unit tests demonstrate how specific template instances ma | **TensorOp** | 75 | 10.2+ | `s4 * s4 + s32 => {s32, s4}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu) | | **TensorOp** | 75 | 10.2+ | `b1 ^ b1 + s32 => {s32, b1}` | { T } x { N } => {N,T} | [example](/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu) | + ## Warp-level Matrix Multiply with Tensor Cores The following table summarizes supported warp level shapes for each TensorOp instruction. diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 6db005dd..5f459223 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -141,7 +141,7 @@ int main() { } ``` -## Launching a GEMM kernel +## Launching a GEMM kernel in CUDA **Example:** launch a mixed-precision GEMM targeting Volta Tensor Cores. ```c++ @@ -235,9 +235,172 @@ Note, the above could be simplified as follows using helper methods defined in ` }); ``` +# CUTLASS Library + +The [CUTLASS Library](./tools/library) defines an API for managing and executing collections of compiled +kernel instances and launching them from host code without template instantiations in client code. + +The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its +kernel selection procedure is intended only to be functionally sufficient. It may not launch the +optimal tile size for a given problem. It chooses the first available kernel whose data types, +layouts, and alignment constraints satisfy the given problem. 
Kernel instances and a data structure
+describing them are completely available to client applications which may choose to implement their
+own selection logic.
+
+[cuBLAS](https://developer.nvidia.com/cublas) offers the best performance and functional coverage
+for dense matrix computations on NVIDIA GPUs.
+
+The CUTLASS Library is used by the CUTLASS Profiler to manage kernel instances, and it is also used
+by several SDK examples.
+
+* [10_planar_complex](/examples/10_planar_complex/planar_complex.cu)
+* [11_planar_complex_array](/examples/11_planar_complex_array/planar_complex_array.cu)
+
+The CUTLASS Library defines enumerated types describing numeric data types, matrix and tensor
+layouts, math operation classes, complex transformations, and more.
+
+Client applications should specify [`tools/library/include`](/tools/library/include) in their
+include paths and link against libcutlass_lib.so.
+
+The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies
+its dependency on the CUTLASS Library with the following CMake command.
+```
+target_link_libraries(
+  10_planar_complex
+  PRIVATE
+  cutlass_lib
+  cutlass_tools_util_includes
+)
+```
+
+A sample kernel launch from host-side C++ is shown as follows.
+
+```c++
+#include "cutlass/library/library.h"
+#include "cutlass/library/handle.h"
+
+int main() {
+
+ //
+ // Define the problem size
+ //
+ int M = 512;
+ int N = 256;
+ int K = 128;
+
+ float alpha = 1.25f;
+ float beta = -1.25f;
+
+ //
+ // Allocate device memory
+ //
+
+ cutlass::HostTensor A({M, K});
+ cutlass::HostTensor B({K, N});
+ cutlass::HostTensor C({M, N});
+
+ float const *ptrA = A.device_data();
+ float const *ptrB = B.device_data();
+ float const *ptrC = C.device_data();
+ float *ptrD = C.device_data();
+
+ int lda = A.device_ref().stride(0);
+ int ldb = B.device_ref().stride(0);
+ int ldc = C.device_ref().stride(0);
+ int ldd = C.device_ref().stride(0);
+
+ //
+ // CUTLASS Library call to execute device GEMM
+ //
+
+ cutlass::library::Handle handle;
+
+ //
+ // Launch GEMM on CUDA device.
+ //
+
+ cutlass::Status status = handle.gemm(
+ M,
+ N,
+ K,
+
+ cutlass::library::NumericTypeID::kF32, // data type of internal accumulation
+ cutlass::library::NumericTypeID::kF32, // data type of alpha/beta scalars
+
+ &alpha, // pointer to alpha scalar
+
+ cutlass::library::NumericTypeID::kF32, // data type of A matrix
+ cutlass::library::LayoutTypeID::kColumnMajor, // layout of A matrix
+ ptrA, // pointer to A matrix in device memory
+ lda, // leading dimension of A matrix
+
+ cutlass::library::NumericTypeID::kF32, // data type of B matrix
+ cutlass::library::LayoutTypeID::kColumnMajor, // layout of B matrix
+ ptrB, // pointer to B matrix in device memory
+ ldb, // leading dimension of B matrix
+
+ &beta, // pointer to beta scalar
+
+ cutlass::library::NumericTypeID::kF32, // data type of C and D matrix
+
+ ptrC, // pointer to C matrix in device memory
+ ldc, // leading dimension of C matrix
+
+ ptrD, // pointer to D matrix in device memory
+ ldd // leading dimension of D matrix
+ );
+
+ if (status != cutlass::Status::kSuccess) {
+ return -1;
+ }
+
+ return 0;
+}
+```
+
+Kernels can be selectively included in the CUTLASS Library by specifying filter
strings when
+executing CMake. For example, only single-precision GEMM kernels can be instantiated as follows.
+
+```bash
+$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=sgemm
+```
+
+Compiling only the kernels desired reduces compilation time.
+
+To instantiate kernels of all tile sizes, data types, and alignment constraints, specify
+`-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`.
+
+Several recipes are defined below for convenience. They may be combined as a comma-delimited list.
+
+**Example.** All kernels for Volta and Turing architectures.
+```bash
+$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" -DCUTLASS_LIBRARY_KERNELS=all
+```
+
+**Example.** All GEMM kernels targeting Turing Tensor Cores.
+```bash
+$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=tensorop*gemm
+```
+
+**Example.** All GEMM kernels with single-precision accumulation.
+```bash
+$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" -DCUTLASS_LIBRARY_KERNELS=s*gemm
+```
+
+**Example.** All kernels which expect A and B to be column-major.
+```bash
+$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" -DCUTLASS_LIBRARY_KERNELS=gemm*nn
+```
+
+**Example.** All planar complex GEMM variants.
+```bash
+$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" -DCUTLASS_LIBRARY_KERNELS=planar_complex
+```
+
+
 # Copyright
 
-Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
``` Redistribution and use in source and binary forms, with or without modification, are permitted diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 5243906e..b03a4d9e 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -33,7 +33,7 @@ target_link_libraries( PUBLIC CUTLASS cutlass_tools_util_includes - $<$:cublas> + $<$:nvidia::cublas> gtest ) @@ -48,6 +48,8 @@ target_link_libraries( PUBLIC cutlass_test_unit_infra ) + +set(CUTLASS_INSTALL_TESTS ON CACHE BOOL "Install test executables") function(cutlass_test_unit_add_executable) @@ -65,7 +67,7 @@ function(cutlass_test_unit_add_executable) PRIVATE cutlass_test_unit_infra cutlass_test_unit_infra_lib - ) + ) string(REGEX REPLACE cutlass_ "" NAME_STEM ${NAME}) @@ -79,7 +81,14 @@ function(cutlass_test_unit_add_executable) ${NAME} ) - # message(STATUS "cutlass_test_unit_add_executable(${NAME} c${NAME_STEM} ${NAME_STEM})") + if (CUTLASS_INSTALL_TESTS) + + install( + TARGETS ${NAME} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + + endif() endfunction() diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 7a6aced0..3bc2823c 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -64,37 +64,22 @@ void FilterArchitecture() { /// Maximum compute capability for which the kernels are enabled int max_compute_capability; - - /// If true, architecture is assumed to be silicon - bool silicon; - } test_filters[] = { - { "SM50*", 50, kMaxDevice, true}, - { "SM60*", 60, kMaxDevice, true}, - { "SM61*", 61, kMaxDevice, true}, - { "SM70*", 70, 75, true}, - { "SM75*", 75, kMaxDevice, true}, + { "SM50*", 50, kMaxDevice}, + { "SM60*", 60, kMaxDevice}, + { "SM61*", 61, kMaxDevice}, + { "SM70*", 70, 75}, + { "SM75*", 75, kMaxDevice}, { 0, 0, false } }; - bool running_on_silicon = false; - for (int i = 0; test_filters[i].filter; ++i) { - if (deviceMajorMinor == test_filters[i].min_compute_capability) { - 
running_on_silicon = test_filters[i].silicon; - break; - } - } - // Set negative test filters std::stringstream ss; ss << "-"; for (int i = 0, j = 0; test_filters[i].filter; ++i) { - if (!running_on_silicon && deviceMajorMinor != test_filters[i].min_compute_capability) { - ss << (j++ ? ":" : "") << test_filters[i].filter; - } - else if (deviceMajorMinor < test_filters[i].min_compute_capability || + if (deviceMajorMinor < test_filters[i].min_compute_capability || deviceMajorMinor > test_filters[i].max_compute_capability) { ss << (j++ ? ":" : "") << test_filters[i].filter; diff --git a/test/unit/core/functional.cu b/test/unit/core/functional.cu index 2bdbb5e0..ba796655 100644 --- a/test/unit/core/functional.cu +++ b/test/unit/core/functional.cu @@ -347,13 +347,13 @@ TEST(Functional, divides_f16x17) { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -void Functional_multiply_add_f16xN() { +template +void Functional_multiply_add_TxN() { - using Element = cutlass::Array; + using Element = cutlass::Array; using Operator = cutlass::multiply_add; - using Tensor = cutlass::HostTensor; + using Tensor = cutlass::HostTensor; Tensor D({1, kN}); Tensor A({1, kN}); @@ -361,10 +361,10 @@ void Functional_multiply_add_f16xN() { Tensor C({1, kN}); for (int i = 0; i < kN; ++i) { - A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); - B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); - C.host_data()[i] = cutlass::half_t((i * 3 + 11) % 11); - D.host_data()[i] = cutlass::half_t(0); + A.host_data()[i] = T((i * 2 + 1) % 5); + B.host_data()[i] = T((i * 4 + 8) % 7); + C.host_data()[i] = T((i * 3 + 11) % 11); + D.host_data()[i] = T(0); } D.sync_device(); @@ -399,12 +399,15 @@ void Functional_multiply_add_f16xN() { EXPECT_TRUE(some_d_nonzero); } +///////////////////////////////////////////////////////////////////////////////////////////////// + TEST(Functional, multiply_add_f16x16) { - Functional_multiply_add_f16xN<16>(); + 
Functional_multiply_add_TxN(); } TEST(Functional, multiply_add_f16x17) { - Functional_multiply_add_f16xN<17>(); + Functional_multiply_add_TxN(); } ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/epilogue/CMakeLists.txt b/test/unit/epilogue/CMakeLists.txt old mode 100644 new mode 100755 index 1948a8ab..8597a79f --- a/test/unit/epilogue/CMakeLists.txt +++ b/test/unit/epilogue/CMakeLists.txt @@ -38,4 +38,4 @@ add_custom_target( test_unit_epilogue_thread test_unit_epilogue_warp test_unit_epilogue_threadblock - ) + ) diff --git a/test/unit/epilogue/thread/CMakeLists.txt b/test/unit/epilogue/thread/CMakeLists.txt index b719784c..81b168a2 100644 --- a/test/unit/epilogue/thread/CMakeLists.txt +++ b/test/unit/epilogue/thread/CMakeLists.txt @@ -23,4 +23,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_epilogue_thread linear_combination.cu - ) + linear_combination_planar_complex.cu +) diff --git a/test/unit/epilogue/thread/linear_combination_planar_complex.cu b/test/unit/epilogue/thread/linear_combination_planar_complex.cu new file mode 100644 index 00000000..c90b8ad0 --- /dev/null +++ b/test/unit/epilogue/thread/linear_combination_planar_complex.cu @@ -0,0 +1,280 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Unit tests for thread-level GEMM +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace epilogue { +namespace thread { + +using FunctorPlanarComplexF32F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float>; + +__global__ void epilogue_thread_functor_planar_complex_f32_f32( + float *output_ptr, + float const *accum_ptr, + float const *source_ptr, + typename FunctorPlanarComplexF32F32::Params params) { + + FunctorPlanarComplexF32F32 linear_combination_op(params); + + auto accum = *reinterpret_cast const *>(accum_ptr); + auto source = *reinterpret_cast const *>(source_ptr); + + *reinterpret_cast*>(output_ptr) = linear_combination_op(accum, source); +} + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_planar_complex, f32) { + + using Element = float; + using ElementOutput = float; + int const kCount = 4; + + using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kCount, + Element, + Element>; + + cutlass::complex alpha(Element(2), Element(1)); + cutlass::complex beta(Element(1), Element(-1)); + + typename Functor::Params params(alpha, beta); + + Functor linear_combination_op(params); + + cutlass::ArrayPlanarComplex source; + cutlass::ArrayPlanarComplex accum; + + // Define arbitrary inputs + for (int i = 0; i < kCount; ++i) { + accum.real[i] = Element(i * 2); + accum.imag[i] = Element((i * 3 % 6) - 3); + source.real[i] = ElementOutput((i * 7 % 9) - 4); + source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); + } + + cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + + // Verify each result + for (int i = 0; i < kCount; ++i) { + + 
cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + + beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); + + cutlass::complex got(destination.real[i], destination.imag[i]); + + EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); + EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); + EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace epilogue { +namespace thread { + +using FunctorPlanarComplexF16F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 4, + float, + float>; + +__global__ void epilogue_thread_functor_planar_complex_f16_f32( + cutlass::half_t *output_ptr, + float const *accum_ptr, + cutlass::half_t const *source_ptr, + typename FunctorPlanarComplexF16F32::Params params, + int N) { + + FunctorPlanarComplexF16F32 linear_combination_op(params); + + + auto accum = *reinterpret_cast const *>(accum_ptr); + auto source = *reinterpret_cast const *>(source_ptr); + + #pragma unroll 1 + for (int n = 0; n < N; ++n) { + source = linear_combination_op(accum, source); + } + + *reinterpret_cast*>(output_ptr) = source; +} + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_planar_complex, f16_f32) { + + using Element = float; + using ElementOutput = cutlass::half_t; + int const kCount = 4; + + using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kCount, + Element, + Element>; + + cutlass::complex alpha(Element(2), Element(1)); + cutlass::complex beta(Element(1), Element(-1)); + + typename Functor::Params params(alpha, beta); + + Functor linear_combination_op(params); + + cutlass::ArrayPlanarComplex source; + cutlass::ArrayPlanarComplex accum; + + // Define arbitrary 
inputs + for (int i = 0; i < kCount; ++i) { + accum.real[i] = Element(i * 2); + accum.imag[i] = Element((i * 3 % 6) - 3); + source.real[i] = ElementOutput((i * 7 % 9) - 4); + source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); + } + + cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + + // Verify each result + for (int i = 0; i < kCount; ++i) { + + cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + + beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); + + cutlass::complex got(destination.real[i], destination.imag[i]); + + EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); + EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); + EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace epilogue { +namespace thread { + +using FunctorPlanarComplexF16F16 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 4, + cutlass::half_t, + cutlass::half_t>; + +__global__ void epilogue_thread_functor_planar_complex_f16_f16( + cutlass::half_t *output_ptr, + cutlass::half_t const *accum_ptr, + cutlass::half_t const *source_ptr, + typename FunctorPlanarComplexF16F16::Params params, + int N) { + + FunctorPlanarComplexF16F16 linear_combination_op(params); + + auto accum = *reinterpret_cast const *>(accum_ptr); + auto source = *reinterpret_cast const *>(source_ptr); + + #pragma unroll 1 + for (int n = 0; n < N; ++n) { + source = linear_combination_op(accum, source); + } + + *reinterpret_cast*>(output_ptr) = source; +} + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_planar_complex, f16_f16) { + + using Element = cutlass::half_t; + using ElementOutput = cutlass::half_t; + int const kCount = 8; + 
+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kCount, + Element, + Element>; + + cutlass::complex alpha(Element(2), Element(1)); + cutlass::complex beta(Element(1), Element(-1)); + + typename Functor::Params params(alpha, beta); + + Functor linear_combination_op(params); + + cutlass::ArrayPlanarComplex source; + cutlass::ArrayPlanarComplex accum; + + // Define arbitrary inputs + for (int i = 0; i < kCount; ++i) { + accum.real[i] = Element(i * 2); + accum.imag[i] = Element((i * 3 % 6) - 3); + source.real[i] = ElementOutput((i * 7 % 9) - 4); + source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); + } + + cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + + // Verify each result + for (int i = 0; i < kCount; ++i) { + + cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + + beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); + + cutlass::complex got(destination.real[i], destination.imag[i]); + + EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); + EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); + EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/CMakeLists.txt b/test/unit/epilogue/threadblock/CMakeLists.txt old mode 100644 new mode 100755 index 4785d734..6e10e15c --- a/test/unit/epilogue/threadblock/CMakeLists.txt +++ b/test/unit/epilogue/threadblock/CMakeLists.txt @@ -30,4 +30,5 @@ cutlass_test_unit_add_executable( epilogue_tensor_op.cu epilogue_volta_tensor_op.cu epilogue_wmma_tensor_op_sm70.cu - ) + epilogue_planar_complex.cu +) diff --git a/test/unit/epilogue/threadblock/epilogue_planar_complex.cu b/test/unit/epilogue/threadblock/epilogue_planar_complex.cu new file mode 100644 index 00000000..de2f8696 --- /dev/null +++ 
b/test/unit/epilogue/threadblock/epilogue_planar_complex.cu
@@ -0,0 +1,506 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Unit tests for thread-level GEMM +*/ + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +// Tensor Op +#include "cutlass/gemm/warp/default_mma_tensor_op.h" + +// Volta Tensor Op +#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" +#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" + +// Simt +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" + +// Epilogue components + +#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "testbed_planar_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_tensor_op_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + Element, LayoutA, + Element, LayoutB, + ElementAccumulator, cutlass::layout::RowMajor + >::Type; + + // + // Output operator + // + + using OutputOp = 
cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f32_tensor_op_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + Element, LayoutA, + Element, LayoutB, + ElementAccumulator, cutlass::layout::RowMajor + >::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, 
+ WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f16_tensor_op_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + Element, LayoutA, + Element, LayoutB, + ElementAccumulator, cutlass::layout::RowMajor + >::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + 
+ EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_volta_tensor_op_64x64_32x32x4) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; + using Element = cutlass::half_t; + + using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; + using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + cutlass::gemm::GemmShape<16, 16, 4>, + 32, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::layout::RowMajor, + cutlass::arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> + >; + + using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + ElementAccumulator, + cutlass::layout::RowMajor, + Policy + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaTensorOp, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f32_64x64_32x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + using Element = float; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using ElementOutput = Element; + using ElementAccumulator = Element; + using ElementCompute = Element; + + using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<4, 8>, + cutlass::layout::RowMajorInterleaved<2>, + cutlass::gemm::GemmShape<4, 4, 1> + > + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaSimt, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_threadblock_epilogue, 
planar_complex_simt_f64_64x64_16x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + int const kElementsPerAccess = 1; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + using Element = double; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using ElementOutput = Element; + using ElementAccumulator = Element; + using ElementCompute = Element; + + using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<4, 8>, + cutlass::layout::RowMajorInterleaved<2>, + cutlass::gemm::GemmShape<4, 4, 1> + > + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< + Shape, + WarpMmaSimt, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpiloguePlanarComplexTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/epilogue_tensor_op.cu b/test/unit/epilogue/threadblock/epilogue_tensor_op.cu index 6662213d..530ca8f4 100644 --- a/test/unit/epilogue/threadblock/epilogue_tensor_op.cu +++ b/test/unit/epilogue/threadblock/epilogue_tensor_op.cu @@ -34,6 +34,7 @@ #include 
"cutlass/half.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" #include "cutlass/gemm/warp/default_mma_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" @@ -45,6 +46,541 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_32x32x32) { + + // + 
// Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = 
ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x64_64x32x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, 
cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x128_32x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // 
Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_32x128_32x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x32_64x32x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; 
+ + using Shape = cutlass::gemm::GemmShape<128, 32, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_256x128_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using 
WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + + +TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x256_64x64x32) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int; + using ElementCompute = float; + int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 256, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; + using Element = ElementOutput; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + 
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_64x64x16) { // diff --git a/test/unit/epilogue/threadblock/testbed_planar_complex.h b/test/unit/epilogue/threadblock/testbed_planar_complex.h new file mode 100644 index 00000000..fca543ae --- /dev/null +++ b/test/unit/epilogue/threadblock/testbed_planar_complex.h @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_planar_complex_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + int64_t imaginary_stride_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + int64_t imaginary_stride_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int64_t imaginary_stride_accum, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx 
= threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D_real( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params_D, + ptr_D + imaginary_stride_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C_real( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params_C, + ptr_C + imaginary_stride_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + // + // Load accumulators + // + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + + accumulator_iterator.load(accumulators.real); + accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop so assembly is clearly visible + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename 
Epilogue_ +> +class EpiloguePlanarComplexTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + + using ComplexElementOutput = cutlass::complex; + using ComplexElementAccumulator = cutlass::complex; + using ComplexElementCompute = cutlass::complex; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensorPlanarComplex accumulator_tensor; + cutlass::HostTensorPlanarComplex source_tensor; + cutlass::HostTensorPlanarComplex output_tensor; + +public: + + // + // Methods + // + + EpiloguePlanarComplexTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + #if 1 + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + #else + + cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); + + #endif + } + + bool run_all() { + + cutlass::complex alpha_values[3]; + + alpha_values[0] = cutlass::complex(1, 0); + alpha_values[1] = cutlass::complex(0, 0); + alpha_values[2] = cutlass::complex(2.25f, -0.5f); + + cutlass::complex beta_values[3]; + + beta_values[0] = cutlass::complex(0, 0); + beta_values[1] = cutlass::complex(1, 0); + beta_values[2] = cutlass::complex(0.5f, -2.25f); + + // Test runtime explodes if we tried to test every case exhaustively. 
This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + cutlass::MatrixCoord problem_size( + quantized_size.row() - m_idx * 3, + quantized_size.column() - n_idx * Epilogue::kElementsPerAccess + ); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = run(problem_size, {alpha, beta}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + output_tensor.imaginary_stride(), + params_C, + source_tensor.device_data(), + source_tensor.imaginary_stride(), + output_params, + problem_size, + accumulator_tensor.device_view_real(), + accumulator_tensor.imaginary_stride() + ); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); 
++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ComplexElementOutput got = output_tensor.at(coord); + + ComplexElementOutput expected = default_output; + + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + ComplexElementOutput src = source_tensor.at(coord); + + ComplexElementCompute tmp = + output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ComplexElementCompute(src.real(), src.imag()); + + expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + + + std::cout << "Incorrect result for problem(" + << problem_size.row() << ", " + << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; + + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + + std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 4750dd8b..750a497b 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -23,9 +23,13 @@ cutlass_test_unit_add_executable( 
cutlass_test_unit_gemm_device + BATCH_SOURCES ON + BATCH_SIZE 4 + gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu - gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu + gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu + gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu @@ -90,6 +94,7 @@ cutlass_test_unit_add_executable( simt_zgemm_tn_sm50.cu simt_zgemm_tt_sm50.cu + gemm_splitk_serial_tensor_op_sm75.cu gemm_splitk_tensor_op_sm75.cu gemm_splitk_tensor_op_sm70.cu gemm_splitk_simt_sm50.cu @@ -145,5 +150,3 @@ cutlass_test_unit_add_executable( gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu ) - - diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu similarity index 100% rename from test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75_slicedk.cu rename to test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75_slicedk.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu similarity index 100% rename from test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75_slicedk.cu rename to test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu diff --git a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu new file mode 100644 index 00000000..aecee047 --- /dev/null +++ b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-level GEMM API for Planar Complex. 
+*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "testbed_planar_complex.h" + +#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s884_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s884_tn : gemm_planar_complex_s884_tn_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s884_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + 
cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s884_nt : gemm_planar_complex_s884_nt_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu index a737cf5c..feb248d2 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu @@ -65,7 +65,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -97,7 +97,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -129,7 +129,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 
128x128x128_64x64x128) { cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -161,7 +161,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -193,7 +193,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu index ebef12f6..22a6d7f4 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu @@ -66,7 +66,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - 
cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -98,7 +98,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -130,7 +130,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -162,7 +162,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -226,7 +226,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x3 cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + 
cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu index 165d404b..a5978933 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu @@ -65,7 +65,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128) { cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -97,7 +97,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128) { cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -129,7 +129,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128) { cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -161,7 +161,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x128_32x64x128) { cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -193,7 +193,7 @@ 
TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128) { cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128) { cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu index 70e69dea..47f959e0 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu @@ -66,7 +66,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 256, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -98,7 +98,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<256, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -130,7 +130,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8 cutlass::gemm::GemmShape<128, 128, 128>, 
cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -162,7 +162,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x cutlass::gemm::GemmShape<64, 128, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -194,7 +194,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -226,7 +226,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x3 cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<8, 8, 32>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu new file mode 100644 index 00000000..3766c11e --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + 
cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git 
a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu new file mode 100644 index 00000000..5def3a2b --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + 
cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + 
cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu index 7f4772b9..12b4effe 100644 --- a/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu @@ -65,11 +65,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 32x64x64_16x32x64) { cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -101,11 +99,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -137,11 +133,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -173,11 +167,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 
64x128x64_32x64x64) { cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -209,11 +201,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -245,11 +235,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -281,11 +269,9 @@ TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 64 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute + 64 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 diff --git 
a/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu index 4a9906b4..d30a644e 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu @@ -65,7 +65,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -97,7 +97,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -129,7 +129,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -161,7 +161,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -193,7 +193,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 
8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu index 1328be0b..53fcbd23 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu @@ -65,7 +65,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16 cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -96,7 +96,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -130,7 +130,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / 
cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu index ce640e82..15bdacc0 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu @@ -65,7 +65,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64) { cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -97,7 +97,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64) { cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -129,7 +129,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64) { cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -161,7 +161,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64) { cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -193,7 +193,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64) { cutlass::gemm::GemmShape<128, 64, 64>, 
cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64) { cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu index b0001dbf..dd88e87f 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu @@ -65,7 +65,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16 cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -97,7 +97,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -131,7 +131,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<32, 8, 16>, - cutlass::epilogue::thread::LinearCombination< + 
cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, @@ -165,7 +165,7 @@ TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu index 6317fd7d..4aa799e5 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu @@ -57,9 +57,8 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -77,11 +76,10 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; - + EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) @@ -96,9 +94,8 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -116,15 +113,80 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<8, 8, 16>, + 
cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); + +} ) + +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 32 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); + +} ) + ///////////////////////////////////////////////////////////////////////////////////////////////// #endif diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu index 0f9ee12b..34a1f3be 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu @@ -65,11 +65,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / 
cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -96,11 +94,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -130,11 +126,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<32, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -164,11 +158,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu index 2a6f6da3..a881ca27 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu @@ -57,9 +57,8 @@ 
CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -76,9 +75,8 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -95,9 +93,8 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); @@ -115,14 +112,55 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 32 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + + test::gemm::device::Testbed testbed; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 32 
/ cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2>; + + test::gemm::device::Testbed testbed; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + ///////////////////////////////////////////////////////////////////////////////////////////////// #endif diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu index c756def2..d2078582 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu @@ -65,11 +65,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -97,11 +95,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x1 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 16, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -131,11 +127,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<32, 8, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + 
cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 @@ -165,11 +159,9 @@ TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16 cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 32, 16>, - cutlass::epilogue::thread::LinearCombinationClamp< + cutlass::epilogue::thread::FastLinearCombinationClamp< ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator + 128 / cutlass::sizeof_bits::value >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 2 diff --git a/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu b/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu new file mode 100644 index 00000000..224c8fbd --- /dev/null +++ b/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmSplitKSerial_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + static const int kStages = 2; + + static const int kAlignmentA = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ElementA, + ElementB, + ElementOutput, + 
ElementAccumulator>::kAlignmentA; + + static const int kAlignmentB = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + ElementA, + ElementB, + ElementOutput, + ElementAccumulator>::kAlignmentB; + + static const bool kSplitKSerial = true; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, + cutlass::layout::ColumnMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + kStages, + kAlignmentA, + kAlignmentB, + kSplitKSerial + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu b/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu index c350b063..725b5feb 100644 --- a/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu +++ b/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu @@ -42,6 +42,7 @@ #include "testbed_splitk.h" +// These tests cannot run unless CUDA 10.1 Toolkit or later is used. 
#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu b/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu index d78164e0..71b606da 100644 --- a/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu +++ b/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu @@ -42,6 +42,7 @@ #include "testbed_splitk.h" +// These tests cannot run unless CUDA 10.2 Toolkit or later is used. #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index 63d88e9b..57108530 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -26,6 +26,8 @@ \brief Tests for device-wide GEMM interface */ +#pragma once + #include #include #include @@ -41,20 +43,7 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/gemm.h" -inline char const *to_string(cutlass::Status status) { - - switch (status) { - case cutlass::Status::kSuccess: return "kSuccess"; - case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; - case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; - case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; - case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; - case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; - case cutlass::Status::kErrorInternal: return "kErrorInternal"; - case cutlass::Status::kInvalid: return "kInvalid"; - } - return "invalid"; -} +#include "testbed_utils.h" namespace test { namespace gemm { @@ -185,9 +174,12 @@ struct Testbed { EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); 
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); - - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); @@ -341,18 +333,12 @@ bool TestAllGemm() { (cutlass::platform::is_same::value || cutlass::platform::is_same::value) ? 4 : kAlignment; - - int problem_size_m[] = { - kAlignmentM, 512 - 3*kAlignmentM - }; + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; - int problem_size_n[] = { - kAlignmentN, 512 - 2*kAlignmentN - }; + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; int problem_size_k[] = { - kAlignmentK, Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK - }; + kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; int split_k_slices[] = { 1, 2, 3 @@ -379,6 +365,10 @@ bool TestAllGemm() { continue; } + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + for (auto alpha : problem_alpha) { for (auto beta : problem_beta) { diff --git a/test/unit/gemm/device/testbed_complex.h b/test/unit/gemm/device/testbed_complex.h index e3372cdd..1eff58a2 100644 --- a/test/unit/gemm/device/testbed_complex.h +++ b/test/unit/gemm/device/testbed_complex.h @@ -26,6 +26,8 @@ \brief Tests for device-wide GEMM interface */ +#pragma once + #include #include #include @@ -90,6 +92,7 @@ struct TestbedComplex : public Testbed { this->tensor_B.host_ref(), Gemm::kTransformB, beta, + this->tensor_C.host_ref(), this->reference_D.host_ref(), ElementAccumulator(0) ); diff --git a/test/unit/gemm/device/testbed_planar_complex.h b/test/unit/gemm/device/testbed_planar_complex.h new 
file mode 100644 index 00000000..5642020b --- /dev/null +++ b/test/unit/gemm/device/testbed_planar_complex.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TestbedPlanarComplex { +public: + + using ElementA = typename Gemm::ElementA; + using LayoutA = typename Gemm::LayoutA; + using ElementB = typename Gemm::ElementB; + using LayoutB = typename Gemm::LayoutB; + using ElementC = typename Gemm::ElementC; + using LayoutC = typename Gemm::LayoutC; + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::HostTensorPlanarComplex tensor_A; + cutlass::HostTensorPlanarComplex tensor_B; + cutlass::HostTensorPlanarComplex tensor_C; + cutlass::HostTensorPlanarComplex tensor_D; + cutlass::HostTensorPlanarComplex tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { + + tensor_A.reset({problem_size.m(), problem_size.k()}); + tensor_B.reset({problem_size.k(), problem_size.n()}); + tensor_C.reset({problem_size.m(), problem_size.n()}); + tensor_D.reset({problem_size.m(), problem_size.n()}); + tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); + } + + void initialize() { + + uint64_t seed = 1073; + + int scope_max = 8; + int 
scope_min = -8; + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorFill(tensor_D_ref.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + bool run( + cutlass::complex alpha = {1, 0}, + cutlass::complex beta = {0, 0}) { + + initialize(); + + int batch_count = 1; + + ElementA *ptr_A = tensor_A.device_data(); + ElementB *ptr_B = tensor_B.device_data(); + ElementC *ptr_C = tensor_C.device_data(); + ElementC *ptr_D = tensor_D.device_data(); + + int lda = tensor_A.layout().stride(0); + int ldb = tensor_B.layout().stride(0); + int ldc = tensor_C.layout().stride(0); + int ldd = tensor_D.layout().stride(0); + + int64_t imag_stride_A = tensor_A.imaginary_stride(); + int64_t imag_stride_B = tensor_B.imaginary_stride(); + int64_t imag_stride_C = tensor_C.imaginary_stride(); + int64_t imag_stride_D = tensor_D.imaginary_stride(); + + // + // Launch device kernel + // + + Gemm gemm_op; + + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + {alpha, beta}, + ptr_A, + ptr_A + imag_stride_A, + ptr_B, + ptr_B + imag_stride_B, + ptr_C, + ptr_C + imag_stride_C, + ptr_D, + ptr_D + imag_stride_D, + lda, + lda, + ldb, + ldb, + ldc, + ldc, + ldd, + ldd + }; + + cutlass::Status status = gemm_op(args); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t error = cudaDeviceSynchronize(); + + tensor_D.sync_host(); + + // + // Compute reference + // + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + 
ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + tensor_D_ref.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_D.host_view(), + tensor_D_ref.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("gemm_planar_complex.txt"); + + output + << "A:\n" << tensor_A.host_view() << "\n" + << "B:\n" << tensor_B.host_view() << "\n" + << "C:\n" << tensor_C.host_view() << "\n" + << "Reference:\n" + << tensor_D_ref.host_view() << "\n" + << "Computed:\n" + << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +template +bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { + + TestbedPlanarComplex testbed(problem_size); + + return testbed.run(); +} + +template +bool TestAllGemmPlanarComplex() { + + int M[] = { + 16, 264, + }; + + int N[] = { + 16, 248, + }; + + int K[] = { + 8, 96, + }; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + cutlass::complex alpha_values[] = { + {ElementCompute(1.25), ElementCompute(-0.5)} + }; + + cutlass::complex beta_values[] = { + {ElementCompute(-2.25), ElementCompute(1.5)} + }; + + for (int m : M) { + for (int n : N) { + for (int k : K) { + + test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = testbed.run(alpha, beta); + if (!passed) { + return false; + } + } + } + } + } + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/test/unit/gemm/device/testbed_splitk.h b/test/unit/gemm/device/testbed_splitk.h index 19f2d1fe..c8ae4b4a 100644 --- 
a/test/unit/gemm/device/testbed_splitk.h +++ b/test/unit/gemm/device/testbed_splitk.h @@ -26,6 +26,8 @@ \brief Tests for device-wide GEMM interface */ +#pragma once + #include #include #include diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h new file mode 100644 index 00000000..44503e0a --- /dev/null +++ b/test/unit/gemm/device/testbed_universal.h @@ -0,0 +1,480 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedUniversal { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedUniversal( + cutlass::Distribution::Kind 
init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed 
+ 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << 
"A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); 
+ + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedUniversal testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllGemmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 
4 : kAlignment; + + + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentM, 512 - 3*kAlignmentM + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1, 2, 3, 5, 7 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { + continue; + } + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + } + + /* + // large problem with high coverage + for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { + TestbedUniversal testbed; + + cutlass::gemm::GemmCoord problem_size(72, 56, 8192); + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + cutlass::from_real(1.0), + cutlass::from_real(2.0) + ); + + if (!passed) { + break; + } + } + */ + + return passed; +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/testbed_utils.h b/test/unit/gemm/device/testbed_utils.h new file mode 100644 index 00000000..5a76c3be --- /dev/null +++ b/test/unit/gemm/device/testbed_utils.h @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +inline char const *to_string(cutlass::Status status) { + + switch (status) { + case cutlass::Status::kSuccess: return "kSuccess"; + case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; + case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "kErrorInternal"; + case cutlass::Status::kInvalid: return "kInvalid"; + } + return "invalid"; +} diff --git a/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h new file mode 100644 index 00000000..5838e4f3 --- /dev/null +++ b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -0,0 +1,345 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + *modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + *notice, this list of conditions and the following disclaimer in the + *documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + *contributors may be used to endorse or promote products derived from this + *software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, + *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING + *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma_planar_complex( + cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::Element *ptr_A, + int64_t imaginary_stride_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::Element *ptr_B, + int64_t imaginary_stride_B, + typename Mma::ElementC *ptr_C, int ldc, int64_t imaginary_stride_C) { + + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA 
iterator_A_real(params_A, ptr_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B_real(params_B, ptr_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum.real); + + iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename Mma_> +struct TestbedPlanarComplex { + + using Mma = Mma_; + using ThreadblockShape = typename Mma::Shape; + using IteratorA = typename Mma::IteratorA; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using IteratorB = typename Mma::IteratorB; + using ElementB = typename 
Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Mma::ElementC; + using ElementAccumulator = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + using ThreadMapA = typename Mma::IteratorA::ThreadMap; + using ThreadMapB = typename Mma::IteratorB::ThreadMap; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = Mma::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + Mma::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + Mma::kCacheOpB; + + // + // Data members + // + + cutlass::HostTensorPlanarComplex matrix_A; + cutlass::HostTensorPlanarComplex matrix_B; + cutlass::HostTensorPlanarComplex matrix_C_computed; + cutlass::HostTensorPlanarComplex matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedPlanarComplex(int m, int n, int k) + : problem_size(m, n, k) { + + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + + for (int i = 0; i < matrix_A.capacity() * 2; ++i) { + 
matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); + } + /* + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity() * 2); + */ + } else if (init_A == cutlass::Distribution::Identity) { + //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + + + } else if (init_B == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity() * 2); + + for (int i = 0; i < matrix_B.capacity() * 2; ++i) { + matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); + } + + + } else if (init_B == cutlass::Distribution::Identity) { + + //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + + } else { + // TODO: Implement the rest + return false; + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma_planar_complex<<>>( + problem_size, + params_A, + matrix_A.device_data(), + matrix_A.imaginary_stride(), + params_B, + matrix_B.device_data(), + matrix_B.imaginary_stride(), + matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), + matrix_C_computed.imaginary_stride() + ); + + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + 
matrix_C_computed.sync_host(); + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + cutlass::complex(ElementAccumulator(1)), + matrix_A.host_ref(), + Mma::kTransformA, + matrix_B.host_ref(), + Mma::kTransformB, + cutlass::complex(ElementAccumulator(0)), + matrix_C_reference.host_ref(), + matrix_C_reference.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), + matrix_C_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/test/unit/gemm/warp/CMakeLists.txt b/test/unit/gemm/warp/CMakeLists.txt index 96cfc29b..600d1d8e 100644 --- a/test/unit/gemm/warp/CMakeLists.txt +++ b/test/unit/gemm/warp/CMakeLists.txt @@ -30,5 +30,4 @@ cutlass_test_unit_add_executable( wmma_sm70.cu wmma_sm72.cu wmma_sm75.cu - testbed.h ) diff --git a/test/unit/gemm/warp/testbed.h b/test/unit/gemm/warp/testbed.h index 47ab7bf0..9560b910 100644 --- a/test/unit/gemm/warp/testbed.h +++ b/test/unit/gemm/warp/testbed.h @@ -102,6 +102,7 @@ __global__ void kernel( FragmentA frag_A; FragmentB frag_B; + FragmentC accum; Mma mma; @@ -306,13 +307,22 @@ struct Testbed { if (!passed) { - cutlass::TensorView tensor_A_physical(tensor_A.host_data(), tensor_A.stride(), tensor_A.extent()); - cutlass::TensorView tensor_B_physical(tensor_B.host_data(), tensor_B.stride(), tensor_B.extent()); + cutlass::TensorView tensor_A_physical( + 
tensor_A.host_data(), + tensor_A.stride(), + tensor_A.extent()); + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride(), + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; @@ -459,6 +469,7 @@ struct TestbedComplex { tensor_B.host_ref(), Mma::kTransformB, ElementC(0), + tensor_C.host_ref(), tensor_D_reference.host_ref() ); @@ -486,13 +497,15 @@ struct TestbedComplex { tensor_B.stride(), tensor_B.extent()); + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical):\n" << tensor_B_physical << "\n\n"; + << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; std::cout << "C:\n" << tensor_C.host_view() << "\n\n" @@ -506,6 +519,484 @@ struct TestbedComplex { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Test kernel +template +__global__ void kernel_transform( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. 
+ __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + using TransformedFragmentA = typename Mma::TransformedFragmentA; + using TransformedFragmentB = typename Mma::TransformedFragmentB; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::LaneId()); + + FragmentA loaded_frag_A; + FragmentB loaded_frag_B; + TransformedFragmentA transformed_frag_A; + TransformedFragmentB transformed_frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { 
+ iter_A.load(loaded_frag_A); + iter_B.load(loaded_frag_B); + + ++iter_A; + ++iter_B; + + mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, + loaded_frag_B); + + mma(accum, transformed_frag_A, transformed_frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct TransformTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + 
tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<>>( + tensor_D_computed.device_data(), tensor_A.device_data(), + tensor_B.device_data(), tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride(), + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride(), + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute 
the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TransformedTestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformedTestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } 
else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + 
tensor_A.stride(), + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride(), + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + } // namespace warp } // namespace gemm } // namespace test diff --git a/test/unit/nvrtc/stdlib/stdint.h b/test/unit/nvrtc/stdlib/stdint.h index d066380e..50ed027d 100644 --- a/test/unit/nvrtc/stdlib/stdint.h +++ b/test/unit/nvrtc/stdlib/stdint.h @@ -33,3 +33,91 @@ typedef int int32_t; typedef unsigned int uint32_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; + +#if defined __x86_64__ && !defined __ILP32__ +# define __WORDSIZE 64 +#else +# define __WORDSIZE 32 +#endif + + +/* Small types. */ + +/* Signed. */ +typedef signed char int_least8_t; +typedef short int int_least16_t; +typedef int int_least32_t; +#if __WORDSIZE == 64 +typedef long int int_least64_t; +#else +__extension__ +typedef long long int int_least64_t; +#endif + +/* Unsigned. 
*/ +typedef unsigned char uint_least8_t; +typedef unsigned short int uint_least16_t; +typedef unsigned int uint_least32_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_least64_t; +#else +__extension__ +typedef unsigned long long int uint_least64_t; +#endif + + +/* Fast types. */ + +/* Signed. */ +typedef signed char int_fast8_t; +#if __WORDSIZE == 64 +typedef long int int_fast16_t; +typedef long int int_fast32_t; +typedef long int int_fast64_t; +#else +typedef int int_fast16_t; +typedef int int_fast32_t; +__extension__ +typedef long long int int_fast64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_fast8_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_fast16_t; +typedef unsigned long int uint_fast32_t; +typedef unsigned long int uint_fast64_t; +#else +typedef unsigned int uint_fast16_t; +typedef unsigned int uint_fast32_t; +__extension__ +typedef unsigned long long int uint_fast64_t; +#endif + +/* Types for `void *' pointers. */ +#if __WORDSIZE == 64 +# ifndef __intptr_t_defined +typedef long int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned long int uintptr_t; +#else +# ifndef __intptr_t_defined +typedef int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned int uintptr_t; +#endif + + +/* Largest integral types. 
*/ +#if __WORDSIZE == 64 +typedef long int intmax_t; +typedef unsigned long int uintmax_t; +#else +__extension__ +typedef long long int intmax_t; +__extension__ +typedef unsigned long long int uintmax_t; +#endif + diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index d32b6fd3..d5e5bd47 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -52,13 +52,15 @@ install( # cutlass_add_library( - cutlass_lib - SHARED - src/library.cu + cutlass_library_objs + OBJECT + src/handle.cu src/manifest.cpp + src/operation_table.cu + src/singleton.cu + src/util.cu + ) -add_library(nvidia::cutlass::library ALIAS cutlass_lib) -set_target_properties(cutlass_lib PROPERTIES EXPORT_NAME library) file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/scripts/*.py) @@ -66,16 +68,19 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU # auto-instantiation of CUTLASS kernels # +# set cutlass generator compiler version to filter kernels in the generator not supported by a specific toolkit. 
+set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) + execute_process( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/scripts/generator.py - --operations all + --operations "${CUTLASS_LIBRARY_OPERATIONS}" --build-dir ${PROJECT_BINARY_DIR} --curr-build-dir ${CMAKE_CURRENT_BINARY_DIR} --generator-target library --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" --kernels "${CUTLASS_LIBRARY_KERNELS}" - --cuda-version "${CMAKE_CUDA_COMPILER_VERSION}" + --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log @@ -95,35 +100,70 @@ else() endif() target_include_directories( - cutlass_lib + cutlass_library_objs PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_BINARY_DIR}/include ) -set_target_properties( - cutlass_lib - PROPERTIES - OUTPUT_NAME cutlass - WINDOWS_EXPORT_ALL_SYMBOLS 1 - ) - target_link_libraries( - cutlass_lib + cutlass_library_objs PUBLIC + cutlass_library_includes + ) + +function(cutlass_add_cutlass_library) + + set(options) + set(oneValueArgs NAME TYPE EXPORT_NAME) + set(multiValueArgs) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + cutlass_add_library( + ${__NAME} + ${__TYPE} + EXPORT_NAME ${__EXPORT_NAME} + $ + ) + + target_link_libraries( + ${__NAME} + PUBLIC cutlass_library_includes + ) + + set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX ${CUTLASS_LIBRARY_DEBUG_POSTFIX}) + + set(OUTPUT_NAME cutlass) + + if (WIN32 AND ${__TYPE} STREQUAL "STATIC") + set(OUTPUT_NAME "${OUTPUT_NAME}.static") + endif() + + set_target_properties( + ${__NAME} + PROPERTIES + OUTPUT_NAME ${OUTPUT_NAME} + WINDOWS_EXPORT_ALL_SYMBOLS 1 + ) + +endfunction() + +cutlass_add_cutlass_library(NAME cutlass_lib TYPE SHARED EXPORT_NAME library) 
+cutlass_add_cutlass_library(NAME cutlass_library_static TYPE STATIC EXPORT_NAME library_static) + +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) install( - DIRECTORY - ${CMAKE_CURRENT_SOURCE_DIR}/include/ - DESTINATION - ${CMAKE_INSTALL_INCLUDEDIR} - ) - -install( - TARGETS cutlass_lib cutlass_library_includes + TARGETS + cutlass_lib + cutlass_library_static + cutlass_library_includes EXPORT NvidiaCutlass RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ) diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h new file mode 100644 index 00000000..1b60eb7d --- /dev/null +++ b/tools/library/include/cutlass/library/handle.h @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief BLAS-like handle used to launch operations on the CUDA device. +*/ + +#pragma once + +#include +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Handle object +class Handle { +private: + + /// Host workspace + static int const kHostWorkspaceSize = (4 << 10); + + /// CUDA device properties + cudaDeviceProp device_; + + /// CUDA stream + cudaStream_t stream_; + + /// Device workspace + void *workspace_; + + /// Size of device workspace in bytes + size_t workspace_size_; + + /// Indicates whether scalars are host or device pointers + ScalarPointerMode scalar_pointer_mode_; + + /// Pointer to the most recently executed operation + Operation const *last_operation_; + +public: + + /// Constructor + Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); + + /// Destructor + ~Handle(); + + /// Move constructor + Handle(Handle && handle); + + /// Move assignment operator + Handle &operator=(Handle && handle); + + // + // Persistent state accessors + // + + /// Returns compute capability of the selected device + int compute_capability() const; + + /// Sets the current CUDA stream + void 
set_stream(cudaStream_t stream); + + /// Gets the current CUDA stream + cudaStream_t get_stream() const; + + /// Gets the device workspace size + size_t get_workspace_size() const; + + /// Gets a pointer to the device workspace allocation in Global Memory + void *get_workspace() const; + + /// Sets the size of device workspace, invalidating calls to get_device_workspace() + void set_workspace_size(size_t bytes); + + /// Gets the scalar pointer mode + ScalarPointerMode get_scalar_pointer_mode() const; + + /// Sets the scalar pointer mode + void set_scalar_pointer_mode(ScalarPointerMode mode); + + /// Gets the most recently executed operation + Operation const *get_last_operation() const; + + // + // Computations + // + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C + Status gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int ldc, /// Leading dimension of C matrix + + void * ptr_D, /// 
Pointer to D matrix + int ldd /// Leading dimension of D matrix + ); + + /// Planar complex GEMM + /// + /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. + /// + Status gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int ldd_real, /// Leading dimension of real part of D 
matrix + int ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count = 1, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real = 0, + int64_t batch_stride_A_imag = 0, + + int64_t batch_stride_B_real = 0, + int64_t batch_stride_B_imag = 0, + + int64_t batch_stride_C_real = 0, + int64_t batch_stride_C_imag = 0, + + int64_t batch_stride_D_real = 0, + int64_t batch_stride_D_imag = 0 + ); + + /// Planar complex GEMM loading pointers from arrays in global memory + Status gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * 
const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices + + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices + + int ldd_real, /// Leading dimension of real part of D matrix + int ldd_imag /// Leading dimension of imaginary part of D matrix + ); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Unique pointer storing the handle +using HandlePtr = std::unique_ptr<Handle>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 787da7cb..f58e3a45 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -68,6 +68,10 @@ enum class LayoutTypeID { kRowMajorInterleavedK4, kColumnMajorInterleavedK16, kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, kTensorNCHW, kTensorNHWC, kInvalid @@ -110,9 +114,21 @@ enum class NumericTypeID { /// Enumeraed type describing a transformation on a complex value. enum class ComplexTransform { kNone, - kConjugate + kConjugate, + kInvalid }; +/// Providers +enum class Provider { + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Enumeration indicating the kind of operation enum class OperationKind { kGemm, @@ -143,6 +159,14 @@ enum class OpcodeClassID { kInvalid }; +enum class MathOperationID { + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddComplex, + kXorPopc, + kInvalid +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Enumeration indicating what kind of GEMM operation to perform @@ -150,88 +174,20 @@ enum class GemmKind { kGemm, kBatched, kArray, + kUniversal, kPlanarComplex, - kPlanarComplexBatched, + kPlanarComplexArray, kInvalid }; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Lexical cast from string -template T from_string(std::string const &); - -/// Converts a NumericType enumerant to a string -char const *to_string(OperationKind type, bool pretty = false); - -/// Parses a NumericType enumerant from a string -template <> OperationKind from_string(std::string const &str); - -/// Converts a NumericType enumerant to a string -char const *to_string(NumericTypeID type, bool pretty 
= false); - -/// Parses a NumericType enumerant from a string -template <> NumericTypeID from_string(std::string const &str); - -/// Returns the size of a data type in bits -int sizeof_bits(NumericTypeID type); - -/// Returns true if the numeric type is a complex data type or false if real-valued. -bool is_complex_type(NumericTypeID type); - -/// Returns the real-valued type underlying a type (only different from 'type' if complex) -NumericTypeID get_real_type(NumericTypeID type); - -/// Returns true if numeric type is integer -bool is_integer_type(NumericTypeID type); - -/// Returns true if numeric type is signed -bool is_signed_type(NumericTypeID type); - -/// Returns true if numeric type is a signed integer -bool is_signed_integer(NumericTypeID type); - -/// returns true if numeric type is an unsigned integer -bool is_unsigned_integer(NumericTypeID type); - -/// Returns true if numeric type is floating-point type -bool is_float_type(NumericTypeID type); - -/// To string method for cutlass::Status -char const *to_string(Status status, bool pretty = false); - -/// Converts a LayoutTypeID enumerant to a string -char const *to_string(LayoutTypeID layout, bool pretty = false); - -/// Parses a LayoutType enumerant from a string -template <> LayoutTypeID from_string(std::string const &str); - -/// Returns the rank of a layout's stride base on the LayoutTypeID -int get_layout_stride_rank(LayoutTypeID layout_id); - -/// Converts a OpcodeClassID enumerant to a string -char const *to_string(OpcodeClassID type, bool pretty = false); - -/// Converts a OpcodeClassID enumerant from a string -template <> -OpcodeClassID from_string(std::string const &str); - -/// Lexical cast from int64_t to string -std::string lexical_cast(int64_t int_value); - -/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); - -/// Lexical cast TO a string FROM a byte array. 
Returns true if cast is successful or false if invalid. -std::string lexical_cast(std::vector &bytes, NumericTypeID type); - -/// Casts from a signed int64 to the destination type. Returns true if successful. -bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); - -/// Casts from an unsigned int64 to the destination type. Returns true if successful. -bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); - -/// Casts from a real value represented as a double to the destination type. Returns true if successful. -bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); +/// Mode of GEMM +enum class GemmUniversalMode { + kGemm, + kGemmSplitKParallel, + kBatched, + kArray, + kInvalid +}; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -246,6 +202,9 @@ struct MathInstructionDescription { /// Classification of math instruction OpcodeClassID opcode_class; + /// Type of math operation performed + MathOperationID math_operation; + // // Methods // @@ -253,9 +212,13 @@ struct MathInstructionDescription { MathInstructionDescription( cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator = NumericTypeID::kInvalid, - OpcodeClassID opcode_class = OpcodeClassID::kInvalid + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd ): - instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {} + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} }; @@ -306,6 +269,9 @@ struct OperationDescription { /// Unique identifier describing the operation char const * name; + /// Operation provider + Provider provider; + /// Kind of operation OperationKind kind; @@ -317,6 +283,7 @@ struct OperationDescription { // 
OperationDescription( char const * name = "unknown", + Provider provider = Provider::kInvalid, OperationKind kind = OperationKind::kInvalid, TileDescription const & tile_description = TileDescription() ): @@ -340,10 +307,11 @@ struct TensorDescription { /// log2() of the maximum value each relevant stride may have int log_stride_range; - + // // Methods // + TensorDescription( NumericTypeID element = NumericTypeID::kInvalid, LayoutTypeID layout = LayoutTypeID::kInvalid, @@ -355,7 +323,7 @@ struct TensorDescription { layout(layout), alignment(alignment), log_extent_range(log_extent_range), - log_stride_range(log_stride_range) { } + log_stride_range(log_stride_range) { } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -414,7 +382,7 @@ struct GemmDescription : public OperationDescription { ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Base class for all device-wide operations +/// Base class for all operations class Operation { public: @@ -435,7 +403,7 @@ public: virtual Status initialize( void const *configuration, void *host_workspace, - void *device_workspace, + void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; virtual Status run( @@ -443,6 +411,7 @@ public: void *host_workspace, void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; + }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -551,11 +520,18 @@ using GemmBatchedArguments = GemmArguments; struct GemmArrayConfiguration { gemm::GemmCoord problem_size; + + /// Leading dimension of A matrix + int64_t lda; - int64_t const *lda; - int64_t const *ldb; - int64_t const *ldc; - int64_t const *ldd; + /// Leading dimension of B matrix + int64_t ldb; + + /// Leading dimension of C matrix + int64_t ldc; + + ///
Leading dimension of D matrix + int64_t ldd; int batch_count; }; @@ -580,49 +556,98 @@ struct GemmArrayArguments { struct GemmPlanarComplexConfiguration { + GemmUniversalMode mode; gemm::GemmCoord problem_size; + int batch_count; - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t ldd; + int64_t lda_real; + int64_t lda_imag; - int64_t imag_stride_A; - int64_t imag_stride_B; - int64_t imag_stride_C; - int64_t imag_stride_D; + int64_t ldb_real; + int64_t ldb_imag; + + int64_t ldc_real; + int64_t ldc_imag; + + int64_t ldd_real; + int64_t ldd_imag; }; -using GemmPlanarComplexArgments = GemmArguments; +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArguments { + + void const *A_real; + void const *A_imag; + + void const *B_real; + void const *B_imag; + + void const *C_real; + void const *C_imag; + + void *D_real; + void *D_imag; + + void const *alpha; + void const *beta; + ScalarPointerMode pointer_mode; + + int64_t batch_stride_A_real; + int64_t batch_stride_A_imag; + + int64_t batch_stride_B_real; + int64_t batch_stride_B_imag; + + int64_t batch_stride_C_real; + int64_t batch_stride_C_imag; + + int64_t batch_stride_D_real; + int64_t batch_stride_D_imag; +}; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Batched complex valued GEMM in which real and imaginary parts are separated by a stride -// -// OperationKind: Gemm -// GemmKind: Planar complex batched -// -struct GemmPlanarComplexBatchedConfiguration { +/// This is a special form of planar complex which loads pointers and problem size +/// from memory. 
+struct GemmPlanarComplexArrayConfiguration { gemm::GemmCoord problem_size; + int batch_count; - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t ldd; + int64_t lda_real; + int64_t lda_imag; - int64_t imag_stride_A; - int64_t imag_stride_B; - int64_t imag_stride_C; - int64_t imag_stride_D; + int64_t ldb_real; + int64_t ldb_imag; - int64_t batched_stride_A; - int64_t batched_stride_B; - int64_t batched_stride_C; - int64_t batched_stride_D; + int64_t ldc_real; + int64_t ldc_imag; + + int64_t ldd_real; + int64_t ldd_imag; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArrayArguments { + + int const *M; + int const *N; + int const *K; + + void const * const * A_real; + void const * const * A_imag; + void const * const * B_real; + void const * const * B_imag; + void const * const * C_real; + void const * const * C_imag; + void * const * D_real; + void * const * D_imag; + + void const * alpha; + void const * beta; + ScalarPointerMode pointer_mode; }; -using GemmPlanarComplexBatchedArguments = GemmArguments; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/manifest.h b/tools/library/include/cutlass/library/manifest.h index 7746fc3d..eaa90b3a 100644 --- a/tools/library/include/cutlass/library/manifest.h +++ b/tools/library/include/cutlass/library/manifest.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -55,10 +55,14 @@ using OperationVector = std::vector>; class Manifest { private: + /// Operation provider + Provider provider_; + /// Global list of operations OperationVector operations_; public: + Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } /// Top-level initialization Status initialize(); diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h new file mode 100644 index 00000000..80ce1e15 --- /dev/null +++ b/tools/library/include/cutlass/library/operation_table.h @@ -0,0 +1,205 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying functional behavior +struct GemmFunctionalKey { + + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + ComplexTransform transform_A; + NumericTypeID element_B; + LayoutTypeID layout_B; + ComplexTransform transform_B; + NumericTypeID element_C; + + // + // Methods + // + + inline + GemmFunctionalKey( + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + ComplexTransform transform_A = ComplexTransform::kNone, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + ComplexTransform transform_B = ComplexTransform::kNone, + NumericTypeID element_C = 
NumericTypeID::kF16 + ): + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + transform_A(transform_A), + element_B(element_B), + layout_B(layout_B), + transform_B(transform_B), + element_C(element_C) + { } + + inline + bool operator==(GemmFunctionalKey const &rhs) const { + return + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (transform_A == rhs.transform_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (transform_B == rhs.transform_B) && + (element_C == rhs.element_C); + } + + inline + bool operator!=(GemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for GemmFunctionalKey +struct GemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8 - shl)); + } + + inline + size_t operator()(GemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.element_compute)), 2) ^ + rotl(hash(int(key.element_scalar)), 3) ^ + rotl(hash(int(key.element_A)), 4) ^ + rotl(hash(int(key.layout_A)), 5) ^ + rotl(hash(int(key.transform_A)), 6) ^ + rotl(hash(int(key.element_B)), 7) ^ + rotl(hash(int(key.layout_B)), 8) ^ + rotl(hash(int(key.transform_B)), 9) ^ + rotl(hash(int(key.element_C)), 10); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for GEMM operators +struct GemmPreferenceKey { + + int compute_capability; + int alignment; + + // + // Methods + // + + GemmPreferenceKey(): compute_capability(), alignment() { } + + GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } + + bool 
operator<(GemmPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); + } + + bool operator==(GemmPreferenceKey const &rhs) const { + return compute_capability == rhs.compute_capability; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using GemmOperationVectorMap = std::map< + GemmPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using GemmOperationFunctionalMap = std::unordered_map< + GemmFunctionalKey, + GemmOperationVectorMap, + GemmFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Table of cutlass::library::Operation instances +class OperationTable { +public: + + /// Map of all operations of type kGemm and gemm_kind of type kGemm + GemmOperationFunctionalMap gemm_operations; + + /// Map of all operations of type kGemm and gemm_kind of type kPlanarComplex + GemmOperationFunctionalMap gemm_planar_complex_operations; + + /// Map of all operations of type kGemm and gemm_kind of type kPlanarComplexArray + GemmOperationFunctionalMap gemm_planar_complex_array_operations; + +public: + + void append(Manifest const &manifest); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); + diff --git a/tools/library/include/cutlass/library/singleton.h b/tools/library/include/cutlass/library/singleton.h new file mode 100644 index 00000000..591ad78f --- /dev/null +++ 
b/tools/library/include/cutlass/library/singleton.h @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Singleton instance stores a Manifest and Operation table +class Singleton { +public: + + /// Manifest object + Manifest manifest; + + /// Operation table referencing the Manifest + OperationTable operation_table; + +public: + + Singleton(); + + static Singleton const &get(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/include/cutlass/library/util.h b/tools/library/include/cutlass/library/util.h new file mode 100644 index 00000000..5ff678e8 --- /dev/null +++ b/tools/library/include/cutlass/library/util.h @@ -0,0 +1,138 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief Utilities accompanying the CUTLASS library for interacting with Library types. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast from string +template T from_string(std::string const &); + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty = false); + +/// Parses a Provider enumerant from a string +template <> Provider from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(OperationKind type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> OperationKind from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(NumericTypeID type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> NumericTypeID from_string(std::string const &str); + +/// Returns the size of a data type in bits +int sizeof_bits(NumericTypeID type); + +/// Returns true if the numeric type is a complex data type or false if real-valued. 
+bool is_complex_type(NumericTypeID type); + +/// Returns the real-valued type underlying a type (only different from 'type' if complex) +NumericTypeID get_real_type(NumericTypeID type); + +/// Returns true if numeric type is integer +bool is_integer_type(NumericTypeID type); + +/// Returns true if numeric type is signed +bool is_signed_type(NumericTypeID type); + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type); + +/// returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type); + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type); + +/// To string method for cutlass::Status +char const *to_string(Status status, bool pretty = false); + +/// Converts a LayoutTypeID enumerant to a string +char const *to_string(LayoutTypeID layout, bool pretty = false); + +/// Parses a LayoutType enumerant from a string +template <> LayoutTypeID from_string(std::string const &str); + +/// Returns the rank of a layout's stride based on the LayoutTypeID +int get_layout_stride_rank(LayoutTypeID layout_id); + +/// Converts a OpcodeClassID enumerant to a string +char const *to_string(OpcodeClassID type, bool pretty = false); + +/// Converts a OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string(std::string const &str); + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty = false); + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str); + +/// Lexical cast from int64_t to string +std::string lexical_cast(int64_t int_value); + +/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. +bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str); + +/// Lexical cast TO a string FROM a byte array.
Returns true if cast is successful or false if invalid. +std::string lexical_cast(std::vector &bytes, NumericTypeID type); + +/// Casts from a signed int64 to the destination type. Returns true if successful. +bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); + +/// Casts from an unsigned int64 to the destination type. Returns true if successful. +bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); + +/// Casts from a real value represented as a double to the destination type. Returns true if successful. +bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index 61658702..cc7d35d2 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -22,7 +22,9 @@ from library import * # class GemmOperation: # - def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue): + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Cohort): + self.operation_kind = OperationKind.Gemm self.arch = arch self.tile_description = tile_description @@ -31,29 +33,75 @@ class GemmOperation: self.B = B self.C = C self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_planar_complex(self): + return self.gemm_kind in 
(GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + return ShortDataTypeNames[self.accumulator_type()] + # def core_name(self): ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + } + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: - inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) - else: - inst_shape = '' - return "%s%s%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], inst_shape, GemmKindNames[self.gemm_kind]) + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) # def extended_name(self): ''' Append data types if they differ from compute type. 
''' - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - else: + if self.is_complex(): extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" extended_name = SubstituteTemplate(extended_name, { 'element_a': DataTypeNames[self.A.element], @@ -63,28 +111,32 @@ class GemmOperation: return extended_name + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + # def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' - if self.tile_description.stages > 2: - threadblock = "%dx%d_%dx%d" % ( - self.tile_description.threadblock_shape[0], - self.tile_description.threadblock_shape[1], - self.tile_description.threadblock_shape[2], - self.tile_description.stages - ) - else: - threadblock = "%dx%d" % (self.tile_description.threadblock_shape[0], self.tile_description.threadblock_shape[1]) + threadblock = self.tile_description.procedural_name() opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) + return SubstituteTemplate( - "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}", + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", { 'opcode_class': opcode_class_name, 'extended_name': self.extended_name(), 'threadblock': threadblock, - 'layout': "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]), + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, } ) @@ -104,7 +156,7 @@ class EmitGemmInstance: ''' Responsible for emitting a CUTLASS template definition''' def __init__(self): - self.template = """ + self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::Gemm< ${element_a}, ${layout_a}, @@ -116,14 +168,45 @@ class EmitGemmInstance: cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombination< + ${epilogue_functor}< ${element_c}, ${epilogue_vector_length}, ${element_accumulator}, ${element_epilogue} >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, - ${stages} + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + 
${math_operation} + ${residual} + >; +""" + self.gemm_complex_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${transform_a}, + ${transform_b}, + ${math_operation} + ${residual} >; """ @@ -135,6 +218,8 @@ class EmitGemmInstance: epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + residual = '' + values = { 'operation_name': operation.procedural_name(), 'element_a': DataTypeTag[operation.A.element], @@ -143,7 +228,7 @@ class EmitGemmInstance: 'layout_b': LayoutTag[operation.B.layout], 'element_c': DataTypeTag[operation.C.element], 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], 'arch': "cutlass::arch::Sm%d" % operation.arch, 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), @@ -157,57 +242,72 @@ class EmitGemmInstance: 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'stages': 
str(operation.tile_description.stages) + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual } - return SubstituteTemplate(self.template, values) + template = self.gemm_complex_template if operation.is_complex() else self.gemm_template + + return SubstituteTemplate(template, values) ################################################################################################### # -class EmitGemmBatchedInstance: +class EmitGemmPlanarComplexInstance: ''' Responsible for emitting a CUTLASS template definition''' def __init__(self): self.template = """ // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::GemmBatched< - ${element_a}, ${layout_a}, - ${element_b}, ${layout_b}, - ${element_c}, ${layout_c}, + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, ${element_accumulator}, ${opcode_class}, ${arch}, cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - cutlass::epilogue::thread::LinearCombination< + cutlass::epilogue::thread::LinearCombinationPlanarComplex< ${element_c}, - ${epilogue_vector_length}, + ${alignment_c}, 
${element_accumulator}, ${element_epilogue} >, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, ${stages}, - ${align_a}, - ${align_b} - >; + ${math_operator} + >::GemmKernel; + + struct ${operation_name} : public Operation_${operation_name} { }; """ def emit(self, operation): warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - #warp_shape[2] = operation.tile_description.math_instruction.instruction_shape[2] - warp_shape[2] = operation.tile_description.threadblock_shape[2] - epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] values = { 'operation_name': operation.procedural_name(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), 'element_c': DataTypeTag[operation.C.element], 'layout_c': LayoutTag[operation.C.layout], 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], @@ -222,139 +322,89 @@ class EmitGemmBatchedInstance: 'instruction_shape_m': 
str(operation.tile_description.math_instruction.instruction_shape[0]), 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), + 'alignment_c': str(operation.C.alignment), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' } return SubstituteTemplate(self.template, values) ################################################################################################### + # -# Generator functions for all layouts -# +class EmitGemmPlanarComplexArrayInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, + ${stages}, + ${math_operator} + >::GemmArrayKernel; + + struct ${operation_name} : public Operation_${operation_name} { }; +""" + + def emit(self, operation): + + warp_shape = 
[operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 
'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + ################################################################################################### -# -def GenerateGemmSimt(gemm_kind, manifest, tile_descriptions, min_cc): - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - # for each tile configuration, emit a GEMM - for tile in tile_descriptions: - for layout in layouts: - - A = TensorDescription(tile.math_instruction.element_a, layout[0], 1) - B = TensorDescription(tile.math_instruction.element_b, layout[1], 1) - C = TensorDescription(tile.math_instruction.element_accumulator, layout[2], 1) - - manifest.append(GemmOperation(gemm_kind, 50, tile, A, B, C, tile.math_instruction.element_accumulator)) - -# -def GenerateGemmTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]): - - # Canonical matrix layouts - canonical_layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - # Interleaved matrix layouts - interleaved_layouts = { - 8: [ - #(LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ], - 4: [ - #(LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - ] - } - - # for each tile configuration, emit a GEMM - for align in 
minimum_alignment: - for tile in tile_descriptions: - - min_input_size = min(DataTypeSize[tile.math_instruction.element_a], DataTypeSize[tile.math_instruction.element_a]) - - # If the data type is large enough, use canonical layouts. - if min_input_size >= 16: - layouts = canonical_layouts - else: - layouts = interleaved_layouts[min_input_size] - - for layout in layouts: - - # - output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ - if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ - else [tile.math_instruction.element_accumulator,] - - align_a = align // DataTypeSize[tile.math_instruction.element_a] - align_b = align // DataTypeSize[tile.math_instruction.element_b] - - - for output_type in output_types: - - rows_per_warp = 8 // tile.warp_count[1] - align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32) - - A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a) - B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b) - C = TensorDescription(output_type, layout[2], max(1, align_c)) - - element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \ - else tile.math_instruction.element_accumulator - - manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue)) - - -# -def GenerateGemmWmmaTensorOp(gemm_kind, manifest, tile_descriptions, min_cc, minimum_alignment = [128,]): - - # Wmma supported matrix layouts - layouts = [ - (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), - ] - - # for each tile configuration, emit a GEMM - for align in minimum_alignment: - for tile in tile_descriptions: - for layout in layouts: - - # - output_types = 
[tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ - if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ - else [tile.math_instruction.element_accumulator,] - - align_a = align // DataTypeSize[tile.math_instruction.element_a] - align_b = align // DataTypeSize[tile.math_instruction.element_b] - - - for output_type in output_types: - - rows_per_warp = 8 // tile.warp_count[1] - align_c = min(int(align / DataTypeSize[output_type]), tile.threadblock_shape[1] * rows_per_warp // 32) - - A = TensorDescription(tile.math_instruction.element_a, layout[0], align_a) - B = TensorDescription(tile.math_instruction.element_b, layout[1], align_b) - C = TensorDescription(output_type, layout[2], max(1, align_c)) - - element_epilogue = DataType.f32 if tile.math_instruction.element_accumulator == DataType.s32 \ - else tile.math_instruction.element_accumulator - - manifest.append(GemmOperation(gemm_kind, min_cc, tile, A, B, C, element_epilogue)) ################################################################################################### # @@ -369,21 +419,40 @@ class EmitGemmConfigurationLibrary: self.instance_emitter = { GemmKind.Gemm: EmitGemmInstance, - GemmKind.Batched: EmitGemmBatchedInstance + GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, + GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance } self.gemm_kind_wrappers = { GemmKind.Gemm: 'GemmOperation', - GemmKind.Batched: 'GemmBatchedOperation', + GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', + GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation' } self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" - self.instance_template = """ + self.instance_template = { + GemmKind.Gemm: """ ${compile_guard_start} manifest.append(new ${gemm_kind}("${operation_name}")); ${compile_guard_end} +""", + GemmKind.PlanarComplex: """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + 
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""", + GemmKind.PlanarComplexArray: """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} """ + } + self.header_template = """ /* Generated by gemm_operation.py - Do not edit. @@ -398,6 +467,14 @@ ${compile_guard_end} #include "library_internal.h" #include "gemm_operation.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass { namespace library { @@ -421,9 +498,11 @@ void initialize_${configuration_name}(Manifest &manifest) { def __enter__(self): self.configuration_file = open(self.configuration_path, "w") - self.configuration_file.write(SubstituteTemplate(self.header_template, { - 'configuration_name': self.configuration_name - })) + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + self.operations = [] return self @@ -431,8 +510,10 @@ void initialize_${configuration_name}(Manifest &manifest) { emitter = self.instance_emitter[operation.gemm_kind]() self.operations.append(operation) - self.configuration_file.write(emitter.emit(operation)) - self.configuration_file.write(SubstituteTemplate(self.instance_template, { + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], { 'configuration_name': self.configuration_name, 'operation_name': operation.procedural_name(), 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind], @@ -443,6 +524,19 @@ void initialize_${configuration_name}(Manifest &manifest) { })) def __exit__(self, exception_type, 
exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + self.configuration_file.write(self.epilogue_template) self.configuration_file.close() diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 6c09f180..4b1a483e 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -16,208 +16,913 @@ from gemm_operation import * # def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): - if semantic_ver_string == '': - cuda_version = [10, 2, 0] - else: - cuda_version = [int(x) for x in semantic_ver_string.split('.')] - return cuda_version >= [major, minor, patch] + # by default, use the latest CUDA Toolkit version + cuda_version = [10, 2, 82] + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + + +################################################################################################### ################################################################################################### # -def GenerateSM50(manifest, args): +def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Cohort): + + if complex_transforms is None: + complex_transforms = 
[(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.args.kernels == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Gemm, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray] + + # by default, only generate the largest tile and largest alignment + if manifest.args.kernels == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for gemm_kind in gemm_kinds: + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, 
complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + manifest.append(GemmOperation(gemm_kind, \ + tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue)) + return + +################################################################################################### +################################################################################################### +################################################################################################### + +# +def GenerateSM50_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] min_cc = 50 max_cc = 1024 - stages = 2 - # single-precision - inst = MathInstruction([1, 1, 1], DataType.f32, DataType.f32, DataType.f32, OpcodeClass.Simt) - tile_descriptions = [ - TileDescription([128, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, 8], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([256, 128, 8], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 8], stages, [4, 1, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 8], stages, [1, 4, 1], inst, min_cc, max_cc), - ] + alignment_constraints = [1,] - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, 
min_cc) - GenerateGemmSimt(GemmKind.Batched, manifest, tile_descriptions, min_cc) + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] - # double precision - inst = MathInstruction([1, 1, 1], DataType.f64, DataType.f64, DataType.f64, OpcodeClass.Simt) - tile_descriptions = [ - TileDescription([128, 128, 8], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 8], stages, [4, 1, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 8], stages, [1, 4, 1], inst, min_cc, max_cc), - ] + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmSimt(GemmKind.Batched, manifest, tile_descriptions, min_cc) +# +def GenerateSM50(manifest, args): + GenerateSM50_Simt(manifest, args) ################################################################################################### +################################################################################################### # -def GenerateSM60(manifest, args): +def GenerateSM60_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, 
LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] min_cc = 60 max_cc = 1024 - stages = 2 - math_instructions = [ - MathInstruction([1, 1, 1], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.Simt), - ] + alignment_constraints = [1,] - tile_descriptions = [] - - for inst in math_instructions: - tile_descriptions += [ - TileDescription([256, 256, 8], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 8], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 8], stages, [1, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 8], stages, [2, 1, 1], inst, min_cc, max_cc), + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), ] - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, 
alignment_constraints) +# +# +def GenerateSM60(manifest, args): + GenerateSM60_Simt(manifest, args) + +################################################################################################### ################################################################################################### # -def GenerateSM61(manifest, args): +def GenerateSM61_Simt(manifest, args): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 4], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] min_cc = 61 max_cc = 1024 - stages = 2 - math_instructions = [ - MathInstruction([1, 1, 4], DataType.s8, DataType.s8, DataType.s32, OpcodeClass.Simt), - ] + alignment_constraints = [1,] - tile_descriptions = [] - - for inst in math_instructions: - tile_descriptions += [ - TileDescription([128, 256, 32], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([256, 128, 32], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 128, 32], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([64, 128, 32], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, 32], stages, [4, 1, 1], inst, min_cc, max_cc), - TileDescription([32, 128, 32], stages, [1, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 32, 32], stages, [2, 1, 1], inst, min_cc, max_cc), + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 1, 1], math_inst, 
min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), ] - GenerateGemmSimt(GemmKind.Gemm, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# + +# +def GenerateSM61(manifest, args): + GenerateSM61_Simt(manifest, args) + +################################################################################################### ################################################################################################### # -def GenerateSM70(manifest, args): +def GenerateSM70_TensorOp_884(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] min_cc = 70 max_cc = 75 - stages = 2 - k_groups = 8 - math_instructions = [ - MathInstruction([8, 8, 4], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.TensorOp), - MathInstruction([8, 8, 4], DataType.f16, DataType.f16, 
DataType.f32, OpcodeClass.TensorOp), - ] + alignment_constraints = [8, 4, 2, 1] - tile_descriptions = [] - - for inst in math_instructions: - kblock = k_groups * inst.instruction_shape[2] - tile_descriptions += [ - TileDescription([256, 128, kblock], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, kblock], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([128, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), ] - if CudaToolkitVersionSatisfies(args.cuda_version, 10, 1): - GenerateGemmTensorOp(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmTensorOp(GemmKind.Batched, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) - # wmma tensor op SM70 Gemm kernels - stages = 2 - k_groups = 2 + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: - math_instructions = [ - MathInstruction([16, 16, 16], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.WmmaTensorOp), - MathInstruction([16, 16, 16], DataType.f16, DataType.f16, DataType.f32, OpcodeClass.WmmaTensorOp), + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +def GenerateSM70_PlanarComplexTensorOp_884(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), ] - tile_descriptions = [] + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] - for inst in math_instructions: - kblock = k_groups * inst.instruction_shape[2] - tile_descriptions += [ - TileDescription([128, 128, kblock], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([64, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + 
+ alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), ] - GenerateGemmWmmaTensorOp(GemmKind.Gemm, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + + +# +def GenerateSM70_WmmaTensorOp_161616(manifest, args): + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 1024 + + alignment_constraints = [8,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, 
min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +def GenerateSM70(manifest, args): + GenerateSM70_TensorOp_884(manifest, args) + GenerateSM70_PlanarComplexTensorOp_884(manifest, args) + #GenerateSM70_WmmaTensorOp_161616(manifest, args) + +################################################################################################### ################################################################################################### # -def GenerateSM75(manifest, args): +def GenerateSM75_TensorOp_1688(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] min_cc = 75 max_cc = 1024 - stages = 2 - k_groups = 4 - math_instructions = [ - 
MathInstruction([16, 8, 8], DataType.f16, DataType.f16, DataType.f16, OpcodeClass.TensorOp), - MathInstruction([16, 8, 8], DataType.f16, DataType.f16, DataType.f32, OpcodeClass.TensorOp), - MathInstruction([8, 8, 16], DataType.s8, DataType.s8, DataType.s32, OpcodeClass.TensorOp), - MathInstruction([8, 8, 32], DataType.s4, DataType.s4, DataType.s32, OpcodeClass.TensorOp) - ] + alignment_constraints = [8, 4, 2, 1] - tile_descriptions = [] - - for inst in math_instructions: - kblock = k_groups * inst.instruction_shape[2] - tile_descriptions += [ - TileDescription([256, 128, kblock], stages, [4, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 256, kblock], stages, [2, 4, 1], inst, min_cc, max_cc), - TileDescription([128, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 128, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([128, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), - TileDescription([64, 64, kblock], stages, [2, 2, 1], inst, min_cc, max_cc), + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), ] - if CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): - GenerateGemmTensorOp(GemmKind.Gemm, manifest, tile_descriptions, min_cc) - GenerateGemmTensorOp(GemmKind.Batched, manifest, tile_descriptions, min_cc) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + # Avoid 
emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM75_PlanarComplexTensorOp_1688(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, 
alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM75_TensorOp_8816_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8816_Interleaved(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + DataType.s8, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, \ + SwizzlingFunctor.Identity) + + for op in operations: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8832_TN(manifest, args): + + if not 
CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 8 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 4 + +# + +# +def GenerateSM75_TensorOp_8832_Interleaved(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + DataType.s4, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, \ + SwizzlingFunctor.Identity) + + for op in operations: + op.C.alignment = 16 + +# + +# +def GenerateSM75_WmmaTensorOp_161616(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 10, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# + +# +def GenerateSM75(manifest, args): + GenerateSM75_TensorOp_1688(manifest, args) + GenerateSM75_PlanarComplexTensorOp_1688(manifest, args) + GenerateSM75_TensorOp_8816_TN(manifest, args) + GenerateSM75_TensorOp_8816_Interleaved(manifest, args) + GenerateSM75_TensorOp_8832_TN(manifest, args) + GenerateSM75_TensorOp_8832_Interleaved(manifest, args) + #GenerateSM75_WmmaTensorOp_161616(manifest, args) ################################################################################################### ################################################################################################### +################################################################################################### + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") - parser.add_argument("--operations", default="gemm", help="Specifies the operation to generate (gemm, all)") + parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)") parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. 
cmake files will be emitted in this directory") parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") - parser.add_argument("--architectures", default='50 60 61 75', help="Target compute architectures") + parser.add_argument("--architectures", default='50;60;61;75', help="Target compute architectures") parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.') - parser.add_argument("--cuda-version", default="10.2.0", help="Semantic version string of CUDA Toolkit") - + parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") args = parser.parse_args() diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index b9ceeb4f..71f521e6 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -153,6 +153,68 @@ DataTypeSize = { ################################################################################################### +# +class ComplexTransform(enum.Enum): + none = enum.auto() + conj = enum.auto() + +# +ComplexTransformTag = { + ComplexTransform.none: 'cutlass::ComplexTransform::kNone', + ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', +} + +# +RealComplexBijection = [ + (DataType.f16, DataType.cf16), + (DataType.f32, DataType.cf32), + (DataType.f64, DataType.cf64), +] + +# +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + +# +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return DataType.invalid + +# +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return DataType.invalid + +# +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum.auto() + gaussian = enum.auto() + 
+################################################################################################### + +# +class MathOperation(enum.Enum): + multiply_add = enum.auto() + multiply_add_saturate = enum.auto() + xor_popc = enum.auto() + multiply_add_complex = enum.auto() +# +MathOperationTag = { + MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', +} + +################################################################################################### + # class LayoutType(enum.Enum): ColumnMajor = enum.auto() @@ -182,6 +244,17 @@ LayoutTag = { LayoutType.TensorNCxHW64: 'cutlass::layout::TensorNCxHW64' } +# +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, + LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, + LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, + LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, + LayoutType.TensorNHWC: LayoutType.TensorNHWC +} + # ShortLayoutTypeNames = { LayoutType.ColumnMajor: 'n', @@ -197,6 +270,14 @@ ShortLayoutTypeNames = { LayoutType.TensorNCxHW64: 'ncxhw64' } +# +ShortComplexLayoutNames = { + (LayoutType.ColumnMajor, ComplexTransform.none): 'n', + (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', + (LayoutType.RowMajor, ComplexTransform.none): 't', + (LayoutType.RowMajor, ComplexTransform.conj): 'h' +} + ################################################################################################### # @@ -244,9 +325,15 @@ ArchitectureNames = { # def SubstituteTemplate(template, values): text = template - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - text = re.sub(regex, value, text) + changed 
= True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext return text ################################################################################################### @@ -256,28 +343,52 @@ class GemmKind(enum.Enum): Gemm = enum.auto() Batched = enum.auto() Array = enum.auto() + Universal = enum.auto() PlanarComplex = enum.auto() - PlanarComplexBatched = enum.auto() + PlanarComplexArray = enum.auto() # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Batched: "gemm_batched", GemmKind.Array: "gemm_array", + GemmKind.Universal: "gemm_universal", GemmKind.PlanarComplex: "gemm_planar_complex", - GemmKind.PlanarComplexBatched: "gemm_planar_complex_batched", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", } +# +class EpilogueFunctor(enum.Enum): + LinearCombination = enum.auto() + LinearCombinationClamp = enum.auto() + +# +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', + EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', +} + +# +class SwizzlingFunctor(enum.Enum): + Cohort = enum.auto() + Identity = enum.auto() + +# +SwizzlingFunctorTag = { + SwizzlingFunctor.Cohort: 'cutlass::gemm::threadblock::GemmCohortThreadblockSwizzle<${layout_a}, ${layout_b}>', + SwizzlingFunctor.Identity: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle', +} ################################################################################################### # class MathInstruction: - def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class): + def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b 
self.element_accumulator = element_accumulator self.opcode_class = opcode_class + self.math_operation = math_operation # @@ -292,16 +403,14 @@ class TileDescription: self.maximum_compute_capability = max_compute def procedural_name(self): - if self.stages == 2: - return "%dx%dx%d" % self.threadblock_shape - elif self.stages > 2: - return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) # class TensorDescription: - def __init__(self, element, layout, alignment = 1): + def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.alignment = alignment + self.complex_transform = complex_transform ################################################################################################### diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 9ff69eb6..38182a1b 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -114,6 +114,16 @@ class Manifest: self.args = args self.compute_capabilities = [int(x) for x in args.architectures.split(';')] + if args.operations == 'all': + self.operations_enabled = [] + else: + + operations_list = [ + OperationKind.Gemm + ] + + self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + if args.kernels == 'all': self.kernel_names = [] else: @@ -142,6 +152,16 @@ void initialize_all(Manifest &manifest) { } // namespace cutlass ''' + # + def _filter_string_matches(self, filter_string, haystack): + ''' Returns true if all substrings appear in the haystack in order''' + substrings = filter_string.split('*') + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub):] + return True # def filter(self, 
operation): @@ -159,6 +179,9 @@ void initialize_all(Manifest &manifest) { if not enabled: return False + if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: + return False + # eliminate duplicates if operation.procedural_name() in self.operations_by_name.keys(): return False @@ -168,11 +191,10 @@ void initialize_all(Manifest &manifest) { name = operation.procedural_name() enabled = False for name_substr in self.kernel_names: - if name_substr in name: + if self._filter_string_matches(name_substr, name): enabled = True break - # todo: filter based on operation kind # todo: filter based on compute data type return enabled # @@ -255,10 +277,11 @@ void initialize_all(Manifest &manifest) { manifest_path = os.path.join(generated_path, "manifest.cmake") with open(manifest_path, "w") as manifest_file: - target_name = 'cutlass_lib' + target_name = 'cutlass_library_objs' target_text = SubstituteTemplate("""cutlass_target_sources( ${target_name} + BATCH_SOURCES ON PRIVATE """, { 'target_name': target_name}) diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index b00f8d2a..102c549a 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -29,8 +29,13 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" + #include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_complex.h" #include "cutlass/gemm/device/gemm_batched.h" +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/library/library.h" #include "library_internal.h" @@ -68,8 +73,10 @@ public: GemmOperationBase(char const *name = "unknown_gemm") { description_.name = name; + description_.provider = Provider::kCUTLASS; description_.kind = OperationKind::kGemm; - + description_.gemm_kind = GemmKind::kGemm; + description_.tile_description.threadblock_shape = make_Coord( 
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, @@ -93,22 +100,23 @@ public: description_.tile_description.math_instruction.opcode_class = OpcodeClassMap::kId; + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + description_.tile_description.minimum_compute_capability = ArchMap::kMin; description_.tile_description.maximum_compute_capability = ArchMap::kMax; - - description_.gemm_kind = GemmKind::kGemm; description_.A = make_TensorDescription(Operator::kAlignmentA); description_.B = make_TensorDescription(Operator::kAlignmentB); description_.C = make_TensorDescription(Operator::kAlignmentC); description_.element_epilogue = NumericTypeMap::kId; - description_.split_k_mode = Operator::kSplitKSerial ? SplitKMode::kSerial : SplitKMode::kNone; - description_.transform_A = ComplexTransform::kNone; - description_.transform_B = ComplexTransform::kNone; + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; } /// Returns the description of the GEMM operation @@ -294,8 +302,24 @@ public: return op->run(stream); } -}; + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << 
operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -360,6 +384,7 @@ protected: *static_cast(arguments->alpha), *static_cast(arguments->beta) ); + operator_args.epilogue = params; } else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ @@ -491,6 +516,593 @@ public: } }; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + description_.gemm_kind = GemmKind::kArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmArrayConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + + operator_args.batch_count = configuration->batch_count; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments 
+ static Status update_arguments_( + OperatorArguments &operator_args, + GemmArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +public: + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const 
*configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplex; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + 
GemmPlanarComplexConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = int(configuration->lda_real); + operator_args.lda_imag = int(configuration->lda_imag); + operator_args.ldb_real = int(configuration->ldb_real); + operator_args.ldb_imag = int(configuration->ldb_imag); + operator_args.ldc_real = int(configuration->ldc_real); + operator_args.ldc_imag = int(configuration->ldc_imag); + operator_args.ldd_real = int(configuration->ldd_real); + operator_args.ldd_imag = int(configuration->ldd_imag); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.batch_stride_A = arguments->batch_stride_A_real; + operator_args.batch_stride_A_imag = 
arguments->batch_stride_A_imag; + operator_args.batch_stride_B = arguments->batch_stride_B_real; + operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; + operator_args.batch_stride_C = arguments->batch_stride_C_real; + operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; + operator_args.batch_stride_D = arguments->batch_stride_D_real; + operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + 
Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplexArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = 
configuration->batch_count; + + operator_args.lda_real = int(configuration->lda_real); + operator_args.lda_imag = int(configuration->lda_imag); + operator_args.ldb_real = int(configuration->ldb_real); + operator_args.ldb_imag = int(configuration->ldb_imag); + operator_args.ldc_real = int(configuration->ldc_real); + operator_args.ldc_imag = int(configuration->ldc_imag); + operator_args.ldd_real = int(configuration->ldd_real); + operator_args.ldd_imag = int(configuration->ldd_imag); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.ptr_M = arguments->M; + operator_args.ptr_N = arguments->N; + operator_args.ptr_K = arguments->K; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + 
GemmPlanarComplexArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + 
status = op->run(stream); + + return status; + } +}; + /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu new file mode 100644 index 00000000..b2345932 --- /dev/null +++ b/tools/library/src/handle.cu @@ -0,0 +1,845 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Library handle. +*/ + +#include +#include + +#include "cutlass/library/handle.h" +#include "cutlass/library/singleton.h" +#include "cutlass/library/util.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructor +Handle::Handle( + cudaStream_t stream, + size_t workspace_size +): + stream_(stream), + workspace_(nullptr), + workspace_size_(0), + scalar_pointer_mode_(ScalarPointerMode::kHost), + last_operation_(nullptr) { + + int device_idx = -1; + + cudaError_t error = cudaGetDevice(&device_idx); + if (error != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + error = cudaGetDeviceProperties(&device_, device_idx); + if (error != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + set_workspace_size(workspace_size); + + Singleton::get(); +} + +/// Destructor +Handle::~Handle() { + if (workspace_) { + + if (workspace_) { + cudaFree(workspace_); + } + + workspace_ = nullptr; + workspace_size_ = 0; + } +} + +/// Move constructor +Handle::Handle(Handle && handle) { + device_ = handle.device_; + workspace_size_ = handle.workspace_size_; + workspace_ = handle.workspace_; + stream_ = handle.stream_; + scalar_pointer_mode_ = handle.scalar_pointer_mode_; + + 
handle.workspace_ = nullptr; + handle.workspace_size_ = 0; +} + +/// Move assignment operator +Handle & Handle::operator=(Handle && handle) { + + device_ = handle.device_; + workspace_size_ = handle.workspace_size_; + workspace_ = handle.workspace_; + stream_ = handle.stream_; + scalar_pointer_mode_ = handle.scalar_pointer_mode_; + + handle.workspace_ = nullptr; + handle.workspace_size_ = 0; + + return *this; +} + +int Handle::compute_capability() const { + return device_.major * 10 + device_.minor; +} + +/// Sets the current CUDA stream +void Handle::set_stream(cudaStream_t stream) { + stream_ = stream; +} + +/// Gets the current CUDA stream +cudaStream_t Handle::get_stream() const { + return stream_; +} + +/// Gets the device workspace size +size_t Handle::get_workspace_size() const { + return workspace_size_; +} + +/// Gets a pointer to the device workspace allocation in Global Memory +void *Handle::get_workspace() const { + return workspace_; +} + +/// Sets the size of device workspace, invalidating previous calls to get_device_workspace() +void Handle::set_workspace_size(size_t bytes) { + if (bytes != workspace_size_) { + + if (workspace_) { + cudaFree(workspace_); + } + + workspace_ = nullptr; + workspace_size_ = bytes; + + if (workspace_size_) { + + cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to allocate workspace"); + } + } + } + + if (workspace_) { + cudaError_t error = cudaMemset(workspace_, 0, workspace_size_); + + if (error != cudaSuccess) { + throw std::runtime_error("Failed to clear workspace"); + } + } +} + +/// Gets the scalar pointer mode +ScalarPointerMode Handle::get_scalar_pointer_mode() const { + return scalar_pointer_mode_; +} + +/// Sets the scalar pointer mode +void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) { + scalar_pointer_mode_ = mode; +} + +/// Gets the last operation +Operation const *Handle::get_last_operation() const { + return 
last_operation_; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the maximum required alignment for each operator +static int maximum_alignment_requirement(GemmDescription const &desc) { + return std::max( + std::max(desc.A.alignment, desc.B.alignment), desc.C.alignment); +} + +/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a +/// given upper limit. +static int gemm_problem_alignment( + int M, + int N, + int K, + NumericTypeID element_A, + void const *ptr_A, + int lda, + int64_t batch_stride_A, + NumericTypeID element_B, + void const *ptr_B, + int ldb, + int64_t batch_stride_B, + NumericTypeID element_C, + void const * ptr_C, + int ldc, + int64_t batch_stride_C, + void const * ptr_D, + int ldd, + int64_t batch_stride_D, + int max_alignment_in_bytes = 16 +) { + + void const *pointers[] = { + ptr_A, ptr_B, ptr_C, ptr_D + }; + + int64_t extents[] = { + M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D + }; + + NumericTypeID elements[] = { + element_A, element_B, element_C + }; + + for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) { + + bool satisfied = true; + + // Can pointers satisfy this? + for (void const *ptr : pointers) { + std::uintptr_t int_ptr = reinterpret_cast(ptr); + + if (int_ptr % max_alignment_in_bytes) { + satisfied = false; + break; + } + } + + if (!satisfied) { + continue; + } + + // Compute the maximum alignment based on element data types + int max_element_alignment = 0; + + for (NumericTypeID type_id : elements) { + int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id); + max_element_alignment = std::max(max_element_alignment, element_alignment); + } + + // Can the problem size and leading dimensions satisfy this? 
+ for (int64_t extent : extents) { + if (extent % max_element_alignment) { + satisfied = false; + break; + } + } + + if (!satisfied) { + continue; + } + + // Yes + return max_element_alignment; + } + + // No alignment satisfies this problem + return 0; +} + +/// Find the best kernel in descending order of preference. +static Operation const * find_gemm_operation( + GemmOperationFunctionalMap::const_iterator operators_it, + GemmPreferenceKey const preference_key) { + + auto cc_it = operators_it->second.upper_bound(preference_key); + + if (cc_it == operators_it->second.begin()) { + return nullptr; + } + + Operation const *operation = nullptr; + + // Search in descending order of compute capability + do { + --cc_it; + + // Search tile sizes in order, for now. + for (auto const * op : cc_it->second) { + + GemmDescription const &desc = static_cast(op->description()); + + int min_cc = desc.tile_description.minimum_compute_capability; + int max_cc = desc.tile_description.maximum_compute_capability; + + int op_alignment = maximum_alignment_requirement(desc); + + if ((min_cc <= preference_key.compute_capability) && + (preference_key.compute_capability <= max_cc) && + (op_alignment <= preference_key.alignment)) { + + operation = op; + break; + } + } + } while (!operation && cc_it != operators_it->second.begin()); + + return operation; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Executes a GEMM computation: D <= alpha * A*B + beta * C +Status Handle::gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation 
applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int ldd /// Leading dimension of D matrix +) { + + // + // Find the operation + // + + GemmFunctionalKey key( + element_compute, + element_scalar, + element_A, + layout_A, + transform_A, + element_B, + layout_B, + transform_B, + element_C + ); + + auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); + + if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { + return cutlass::Status::kErrorNotSupported; + } + + if (operators_it->second.empty()) { + return cutlass::Status::kErrorNotSupported; + } + + // + // Compute the largest alignment restriction the kernel can satisfy. + // + + // Maximum alignment expectation among all kernels (in units of bytes) + int const kMaximumAlignmentSize = 16; + + int alignment = gemm_problem_alignment( + M, N, K, + element_A, ptr_A, lda, 0, + element_B, ptr_B, ldb, 0, + element_C, ptr_C, ldc, 0, + ptr_D, ldd, 0, kMaximumAlignmentSize + ); + + // + // Find the best kernel in descending order of preference. 
+ // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmConfiguration configuration{ + {M, N, K}, + lda, + ldb, + ldc, + ldd, + 1 + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmArguments arguments{ + ptr_A, + ptr_B, + ptr_C, + ptr_D, + alpha, + beta, + scalar_pointer_mode_ + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Planar complex GEMM +Status Handle::gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix 
+ + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int ldc_real, /// Leading dimension of real part of C matrix + int ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int ldd_real, /// Leading dimension of real part of D matrix + int ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real, + int64_t batch_stride_A_imag, + + int64_t batch_stride_B_real, + int64_t batch_stride_B_imag, + + int64_t batch_stride_C_real, + int64_t batch_stride_C_imag, + + int64_t batch_stride_D_real, + int64_t batch_stride_D_imag +) { + + // + // Find the operation + // + + GemmFunctionalKey key( + element_compute, + element_scalar, + element_A, + layout_A, + transform_A, + element_B, + layout_B, + transform_B, + element_C + ); + + auto operators_it = Singleton::get().operation_table.gemm_planar_complex_operations.find(key); + + if (operators_it == 
Singleton::get().operation_table.gemm_planar_complex_operations.end()) { + return cutlass::Status::kErrorNotSupported; + } + + if (operators_it->second.empty()) { + return cutlass::Status::kErrorNotSupported; + } + + // + // Compute the largest alignment restriction the kernel can satisfy. + // + + // Maximum alignment expectation among all kernels (in units of bytes) + int const kMaximumAlignmentSize = 16; + + int alignment = std::max( + gemm_problem_alignment( + M, N, K, + element_A, ptr_A_real, lda_real, batch_stride_A_real, + element_B, ptr_B_real, ldb_real, batch_stride_B_real, + element_C, ptr_C_real, ldc_real, batch_stride_C_real, + ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize + ), + gemm_problem_alignment( + M, N, K, + element_A, ptr_A_imag, lda_imag, batch_stride_A_imag, + element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag, + element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag, + ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize + ) + ); + + // + // Find the best kernel in descending order of preference. 
+ // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmPlanarComplexConfiguration configuration{ + GemmUniversalMode::kBatched, + {M, N, K}, + batch_count, + lda_real, + lda_imag, + ldb_real, + ldb_imag, + ldc_real, + ldc_imag, + ldd_real, + ldd_imag + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmPlanarComplexArguments arguments{ + ptr_A_real, + ptr_A_imag, + ptr_B_real, + ptr_B_imag, + ptr_C_real, + ptr_C_imag, + ptr_D_real, + ptr_D_imag, + alpha, + beta, + scalar_pointer_mode_, + batch_stride_A_real, + batch_stride_A_imag, + batch_stride_B_real, + batch_stride_B_imag, + batch_stride_C_real, + batch_stride_C_imag, + batch_stride_D_real, + batch_stride_D_imag + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Planar complex batched GEMM loading pointers from arrays in global memory +Status Handle::gemm_planar_complex_array( + + int expected_M, /// 
Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int lda_real, /// Leading dimension of real part of A matrix + int lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int ldb_real, /// Leading dimension of real part of B matrix + int ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * 
ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices
+
+ int ldc_real, /// Leading dimension of real part of C matrix
+ int ldc_imag, /// Leading dimension of imaginary part of C matrix
+
+ void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices
+ void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices
+
+ int ldd_real, /// Leading dimension of real part of D matrix
+ int ldd_imag /// Leading dimension of imaginary part of D matrix
+) {
+
+ //
+ // Find the operation
+ //
+
+ GemmFunctionalKey key(
+ element_compute,
+ element_scalar,
+ element_A,
+ layout_A,
+ transform_A,
+ element_B,
+ layout_B,
+ transform_B,
+ element_C
+ );
+
+ auto operators_it = Singleton::get().operation_table.gemm_planar_complex_array_operations.find(key);
+
+ if (operators_it == Singleton::get().operation_table.gemm_planar_complex_array_operations.end()) {
+ return cutlass::Status::kErrorNotSupported;
+ }
+
+ if (operators_it->second.empty()) {
+ return cutlass::Status::kErrorNotSupported;
+ }
+
+ //
+ // Compute the largest alignment restriction the kernel can satisfy.
+ //
+
+ // Maximum alignment expectation among all kernels (in units of bytes)
+ int const kMaximumAlignmentSize = 16;
+
+ int alignment = std::max(
+ gemm_problem_alignment(
+ expected_M, expected_N, expected_K,
+ element_A, nullptr, lda_real, 0,
+ element_B, nullptr, ldb_real, 0,
+ element_C, nullptr, ldc_real, 0,
+ nullptr, ldd_real, 0, kMaximumAlignmentSize
+ ),
+ gemm_problem_alignment(
+ expected_M, expected_N, expected_K,
+ element_A, nullptr, lda_imag, 0,
+ element_B, nullptr, ldb_imag, 0,
+ element_C, nullptr, ldc_imag, 0,
+ nullptr, ldd_imag, 0, kMaximumAlignmentSize
+ )
+ );
+
+ //
+ // Find the best kernel in descending order of preference. 
+ // + + GemmPreferenceKey preference_key(compute_capability(), alignment); + + Operation const *operation = find_gemm_operation(operators_it, preference_key); + + if (!operation) { + return cutlass::Status::kErrorNotSupported; + } + + last_operation_ = operation; + + // + // Configure operation + // + + GemmPlanarComplexArrayConfiguration configuration{ + {expected_M, expected_N, expected_K}, + batch_count, + lda_real, + lda_imag, + ldb_real, + ldb_imag, + ldc_real, + ldc_imag, + ldd_real, + ldd_imag + }; + + // Query host work space size + uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); + + if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + char host_workspace[kHostWorkspaceSize]; + + // Query device workspace size + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + + if (uint64_t(workspace_size_) < device_workspace_size_needed) { + return cutlass::Status::kErrorNotSupported; + } + + // Initialize host and device workspaces + Status status = operation->initialize( + &configuration, + host_workspace, + workspace_, + stream_); + + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run the operator + GemmPlanarComplexArrayArguments arguments{ + M, N, K, + ptr_A_real, + ptr_A_imag, + ptr_B_real, + ptr_B_imag, + ptr_C_real, + ptr_C_imag, + ptr_D_real, + ptr_D_imag, + alpha, + beta, + scalar_pointer_mode_ + }; + + return operation->run(&arguments, host_workspace, workspace_, stream_); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 5feff5fb..252d474e 100644 --- 
a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -57,6 +57,10 @@ namespace library { template struct NumericTypeMap; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kS4; }; @@ -123,6 +127,28 @@ template <> struct NumericTypeMap > { ///////////////////////////////////////////////////////////////////////////////////////////////// +template struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + template struct LayoutMap; template <> struct LayoutMap { @@ -133,6 +159,34 @@ template <> struct LayoutMap { static LayoutTypeID const kId = LayoutTypeID::kRowMajor; }; +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = 
LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template struct OpcodeClassMap; @@ -148,6 +202,19 @@ template <> struct OpcodeClassMap { template <> struct OpcodeClassMap { static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; }; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ComplexTransformMap; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; +}; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template struct ArchMap; diff --git a/tools/library/src/manifest.cpp b/tools/library/src/manifest.cpp index 159bf3f0..ca6d1781 100644 --- a/tools/library/src/manifest.cpp +++ b/tools/library/src/manifest.cpp @@ -1,6 +1,4 @@ -/*! - -*//*************************************************************************************************** +/*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted @@ -37,11 +35,12 @@ namespace cutlass { namespace library { -/////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////////////// -void initialize_all(Manifest &manifest); +// init and insert all cutlass op in manifest object (procedurally generated using generator.py) +void initialize_all(Manifest &manifest); -/////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////////////// /// Top-level initialization Status Manifest::initialize() { @@ -50,7 +49,13 @@ Status Manifest::initialize() { operations_.clear(); } - initialize_all(*this); + switch(provider_) { + case Provider::kCUTLASS: + initialize_all(*this); break; + + default: + break; + } return Status::kSuccess; } diff --git a/tools/library/src/operation_table.cu b/tools/library/src/operation_table.cu new file mode 100644 index 00000000..8fb0fe63 --- /dev/null +++ b/tools/library/src/operation_table.cu @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*
+ \file
+ \brief Defines a data structure in which a set of functionally equivalent library::Operation
+ instances may be queried. 
+*/ + +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/operation_table.h" +#include "cutlass/library/util.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { + + out << "{\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " transform_A: " << to_string(k.transform_A) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " transform_B: " << to_string(k.transform_B) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +void OperationTable::append(Manifest const &manifest) { + + // Insert operations into appropriate data structure + for (auto const & operation : manifest) { + + OperationDescription const &desc = operation->description(); + + if (desc.kind == OperationKind::kGemm) { + GemmDescription const &gemm_desc = static_cast(desc); + + if (gemm_desc.gemm_kind == GemmKind::kGemm) { + + GemmFunctionalKey functional_key( + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + gemm_desc.A.element, + gemm_desc.A.layout, + gemm_desc.transform_A, + gemm_desc.B.element, + gemm_desc.B.layout, + gemm_desc.transform_B, + gemm_desc.C.element + ); + + Operation const *op = operation.get(); + + int cc = gemm_desc.tile_description.minimum_compute_capability; + + int alignment = std::max(std::max( + gemm_desc.A.alignment, 
gemm_desc.B.alignment), gemm_desc.C.alignment); + + GemmPreferenceKey preference_key(cc, alignment); + + gemm_operations[functional_key][preference_key].push_back(op); + } + else if (gemm_desc.gemm_kind == GemmKind::kPlanarComplex) { + + GemmFunctionalKey functional_key( + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + gemm_desc.A.element, + gemm_desc.A.layout, + gemm_desc.transform_A, + gemm_desc.B.element, + gemm_desc.B.layout, + gemm_desc.transform_B, + gemm_desc.C.element + ); + + Operation const *op = operation.get(); + + int cc = gemm_desc.tile_description.minimum_compute_capability; + + int alignment = std::max(std::max( + gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment); + + GemmPreferenceKey preference_key(cc, alignment); + + gemm_planar_complex_operations[functional_key][preference_key].push_back(op); + } + else if (gemm_desc.gemm_kind == GemmKind::kPlanarComplexArray) { + + GemmFunctionalKey functional_key( + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + gemm_desc.A.element, + gemm_desc.A.layout, + gemm_desc.transform_A, + gemm_desc.B.element, + gemm_desc.B.layout, + gemm_desc.transform_B, + gemm_desc.C.element + ); + + Operation const *op = operation.get(); + + int cc = gemm_desc.tile_description.minimum_compute_capability; + + int alignment = std::max(std::max( + gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment); + + GemmPreferenceKey preference_key(cc, alignment); + + gemm_planar_complex_array_operations[functional_key][preference_key].push_back(op); + } + } + } + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/singleton.cu b/tools/library/src/singleton.cu new file mode 100644 
index 00000000..642ac61a --- /dev/null +++ b/tools/library/src/singleton.cu @@ -0,0 +1,63 @@ +/***************************************************************************************************
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" +#include "cutlass/library/singleton.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static std::unique_ptr instance; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +Singleton::Singleton() { + + manifest.initialize(); + + operation_table.append(manifest); +} + +Singleton const & Singleton::get() { + if (!instance.get()) { + instance.reset(new Singleton); + } + return *instance.get(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/library.cu b/tools/library/src/util.cu similarity index 86% rename from tools/library/src/library.cu rename to tools/library/src/util.cu index 92f87c61..12757292 100644 --- a/tools/library/src/library.cu +++ b/tools/library/src/util.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -25,17 +25,65 @@ #include #include - #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/complex.h" -#include "cutlass/library/library.h" #include "cutlass/layout/matrix.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + namespace cutlass { namespace library { +///////////////////////////////////////////////////////////////////////////////////////////////// + +static struct { + char const *text; + char const *pretty; + Provider enumerant; +} +Provider_enumerants[] = { + {"cutlass", "CUTLASS", Provider::kCUTLASS}, + {"host", "reference_host", Provider::kReferenceHost}, + {"device", "reference_device", Provider::kReferenceDevice}, + {"cublas", "cuBLAS", Provider::kCUBLAS}, +}; + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty) { + + for (auto const & possible : Provider_enumerants) { + if (provider == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? 
"Invalid" : "invalid"; +} + +/// Parses a Provider enumerant from a string +template <> +Provider from_string(std::string const &str) { + + for (auto const & possible : Provider_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return Provider::kInvalid; +} + + + ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { @@ -44,7 +92,7 @@ static struct { OperationKind enumerant; } OperationKind_enumerants[] = { - {"gemm", "Gemm", OperationKind::kGemm}, + {"gemm", "Gemm", OperationKind::kGemm}, }; /// Converts a Status enumerant to a string @@ -203,6 +251,9 @@ int sizeof_bits(NumericTypeID type) { case NumericTypeID::kF16: return 16; case NumericTypeID::kF32: return 32; case NumericTypeID::kF64: return 64; + case NumericTypeID::kCF16: return 32; + case NumericTypeID::kCF32: return 64; + case NumericTypeID::kCF64: return 128; case NumericTypeID::kS4: return 4; case NumericTypeID::kS8: return 8; case NumericTypeID::kS16: return 16; @@ -291,6 +342,9 @@ bool is_float_type(NumericTypeID type) { case NumericTypeID::kF16: return true; case NumericTypeID::kF32: return true; case NumericTypeID::kF64: return true; + case NumericTypeID::kCF16: return true; + case NumericTypeID::kCF32: return true; + case NumericTypeID::kCF64: return true; default: break; } return false; @@ -309,8 +363,18 @@ layout_aliases[] = { {LayoutTypeID::kColumnMajor, "column"}, {LayoutTypeID::kColumnMajor, "col"}, {LayoutTypeID::kColumnMajor, "n"}, + + {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, + {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, + + {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, + {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, + + {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, + {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, + {LayoutTypeID::kTensorNCHW, "nchw"}, - {LayoutTypeID::kTensorNHWC, "packed_nhwc"}, + 
{LayoutTypeID::kTensorNHWC, "nhwc"}, {LayoutTypeID::kUnknown, "*"}, {LayoutTypeID::kInvalid, nullptr} }; @@ -344,7 +408,12 @@ int get_layout_stride_rank(LayoutTypeID layout_id) { case LayoutTypeID::kColumnMajorInterleavedK4: case LayoutTypeID::kRowMajorInterleavedK4: case LayoutTypeID::kColumnMajorInterleavedK16: - case LayoutTypeID::kRowMajorInterleavedK16: return 1; + case LayoutTypeID::kRowMajorInterleavedK16: + case LayoutTypeID::kColumnMajorInterleavedK32: + case LayoutTypeID::kRowMajorInterleavedK32: + case LayoutTypeID::kColumnMajorInterleavedK64: + case LayoutTypeID::kRowMajorInterleavedK64: + return 1; case LayoutTypeID::kTensorNCHW: case LayoutTypeID::kTensorNHWC: return 3; default : throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank"); @@ -396,8 +465,51 @@ OpcodeClassID from_string(std::string const &str) { return OpcodeClassID::kInvalid; } -/////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// +static struct { + char const *text; + char const *pretty; + ComplexTransform enumerant; +} +ComplexTransform_enumerants[] = { + {"n", "none", ComplexTransform::kNone}, + {"c", "conj", ComplexTransform::kConjugate} +}; + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty) { + + for (auto const & possible : ComplexTransform_enumerants) { + if (type == possible.enumerant) { + if (pretty) { + return possible.pretty; + } + else { + return possible.text; + } + } + } + + return pretty ? 
"Invalid" : "invalid"; +} + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str) { + + for (auto const & possible : ComplexTransform_enumerants) { + if ((str.compare(possible.text) == 0) || + (str.compare(possible.pretty) == 0)) { + return possible.enumerant; + } + } + + return ComplexTransform::kInvalid; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// /// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str) { int size_bytes = sizeof_bits(type) / 8; @@ -574,25 +686,36 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { break; case NumericTypeID::kCF16: { - std::complex tmp; - cutlass::complex const *x = reinterpret_cast const *>(bytes.data()); - tmp.real(x->real()); - tmp.imag(x->imag()); + ss << float(x->real()); - ss << tmp; + if (x->imag() != cutlass::half_t()) { + ss << "+i" << float(x->imag()); + } } break; case NumericTypeID::kCF32: { - ss << *reinterpret_cast*>(bytes.data()); + cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); + + ss << x->real(); + + if (x->imag() != float()) { + ss << "+i" << x->imag(); + } } break; case NumericTypeID::kCF64: { - ss << *reinterpret_cast*>(bytes.data()); + cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); + + ss << x->real(); + + if (x->imag() != double()) { + ss << "+i" << x->imag(); + } } break; default: diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index b3a06900..6e822c68 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -33,7 +33,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES src/gpu_timer.cpp src/device_allocation.cu src/device_context.cu - src/cublas_helpers.cpp + src/cublas_helpers.cpp src/problem_space.cpp src/operation_profiler.cu 
src/gemm_operation_profiler.cu @@ -54,11 +54,11 @@ set_target_properties(cutlass_profiler PROPERTIES EXPORT_NAME profiler) # Include paths # -target_include_directories(cutlass_profiler +target_include_directories( + cutlass_profiler PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src # Source directory - ../../tools/util/include -) + ) # # Library dependencies @@ -68,8 +68,8 @@ target_link_libraries( cutlass_profiler PRIVATE cutlass_lib - $<$:cublas> - gtest + cutlass_tools_util_includes + $<$:nvidia::cublas> cudart ) diff --git a/tools/profiler/src/cublas_helpers.cpp b/tools/profiler/src/cublas_helpers.cpp index 973dc44c..5e5e2cb0 100644 --- a/tools/profiler/src/cublas_helpers.cpp +++ b/tools/profiler/src/cublas_helpers.cpp @@ -39,14 +39,14 @@ namespace profiler { /// Converts a cuBLAS status to cutlass::Status Status get_cutlass_status(cublasStatus_t cublas) { - if (cublas == CUBLAS_STATUS_SUCCESS) { - return Status::kSuccess; - } - else if (cublas == CUBLAS_STATUS_INVALID_VALUE) { - return Status::kErrorInvalidProblem; - } - if (cublas == CUBLAS_STATUS_NOT_SUPPORTED) { - return Status::kErrorNotSupported; + switch (cublas) { + case CUBLAS_STATUS_SUCCESS: + return Status::kSuccess; + case CUBLAS_STATUS_INVALID_VALUE: + return Status::kErrorInvalidProblem; + case CUBLAS_STATUS_NOT_SUPPORTED: + return Status::kErrorNotSupported; + default: break; } return Status::kErrorInternal; } @@ -145,6 +145,13 @@ Status cublas_satisfies(library::GemmDescription const &desc) { return Status::kErrorNotSupported; } + // output type S4 and S8 not supported in cuBLAS + if (desc.C.element == library::NumericTypeID::kS4 || + desc.C.element == library::NumericTypeID::kS8) { + + return Status::kErrorNotSupported; + } + return Status::kSuccess; } diff --git a/tools/profiler/src/cublas_helpers.h b/tools/profiler/src/cublas_helpers.h index 6bb2f4e9..0ade0961 100644 --- a/tools/profiler/src/cublas_helpers.h +++ b/tools/profiler/src/cublas_helpers.h @@ -33,7 +33,7 @@ #include "cutlass/cutlass.h" 
#include "cutlass/library/library.h" - +#include "options.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -86,6 +86,161 @@ public: ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +/// Selects one or more cuBLAS algorithms. +static void select_cublas_algorithms( + std::vector &algorithms, + Options const &options, + library::GemmDescription const &op_desc) { + + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + switch (options.library.algorithm_mode) { + case AlgorithmMode::kMatching: + { + algorithms.push_back(get_cublas_gemm_algo( + op_desc.tile_description.threadblock_shape.m(), + op_desc.tile_description.threadblock_shape.n(), + op_desc.tile_description.threadblock_shape.k(), + opcode_class)); + break; + } + + case AlgorithmMode::kBest: + { + // Choose first enumerated mode. If none are enumerated, choose based on opcode class + // and evaluate all of them. + + if (options.library.algorithms.empty()) { + // Enumerate all algorithms + if (opcode_class == library::OpcodeClassID::kSimt) { + + for (int algo = CUBLAS_GEMM_DEFAULT; + algo <= CUBLAS_GEMM_ALGO23; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + else { + + for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + } + else { + // Use the listed algorithms + algorithms.reserve(options.library.algorithms.size()); + + for (int algo : options.library.algorithms) { + algorithms.push_back(reinterpret_cast(algo)); + } + } + + break; + } + + case AlgorithmMode::kDefault: + { + + // Use the library's default algorithm + algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? 
+ CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + break; + } + default: + { + break; + } + } +} + +/// Dispatcher to cublasGemmEx() +struct cublasGemmExDispatcher { + + // + // Data members + // + library::GemmConfiguration configuration; + library::GemmArguments arguments; + + // cublass-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_type; + cublasGemmAlgo_t algo; + Status status; + + // + // Methods + // + + cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmConfiguration configuration_, + library::GemmArguments arguments_, + cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT + ): + configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) { + + trans_A = get_cublas_transpose_operation(op_desc.A.layout); + trans_B = get_cublas_transpose_operation(op_desc.B.layout); + + bool good = true; + good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); + good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); + good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); + + good = (good && get_cublas_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + + if (!good) { + status = Status::kErrorNotSupported; + } + } + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle) { + + return cublasGemmEx( + handle, + trans_A, + trans_B, + configuration.problem_size.m(), + configuration.problem_size.n(), + configuration.problem_size.k(), + arguments.alpha, + arguments.A, + data_type_A, + int(configuration.lda), + arguments.B, + data_type_B, + int(configuration.ldb), + arguments.beta, + arguments.D, + data_type_C, + int(configuration.ldc), + compute_type, + algo + ); + } +}; + 
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index e1664de7..b36f897b 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -29,14 +29,9 @@ #include #include -// CUTLASS Library includes -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - // Profiler includes #include "cutlass_profiler.h" #include "gemm_operation_profiler.h" - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -49,7 +44,8 @@ CutlassProfiler::CutlassProfiler( ): options_(options) { - operation_profilers_.emplace_back(new GemmOperationProfiler); + operation_profilers_.emplace_back(new GemmOperationProfiler); + } CutlassProfiler::~CutlassProfiler() { @@ -112,7 +108,7 @@ void CutlassProfiler::enumerate_() { /// Profiles all operations int CutlassProfiler::profile_() { - library::Manifest manifest; + library::Manifest manifest(library::Provider::kCUTLASS); Status status = manifest.initialize(); if (status != Status::kSuccess) { @@ -165,7 +161,8 @@ void CutlassProfiler::print_usage_(std::ostream &out) { } out << "\n\nFor details about a particular function, specify the function name with --help.\n\nExample:\n\n" - << " $ cutlass_profiler --operation=Gemm --help\n\n"; + << " $ cutlass_profiler --operation=Gemm --help\n\n" + ; } /// Prints usage diff --git a/tools/profiler/src/cutlass_profiler.h b/tools/profiler/src/cutlass_profiler.h index 09d24011..eda24c5b 100644 --- a/tools/profiler/src/cutlass_profiler.h +++ b/tools/profiler/src/cutlass_profiler.h @@ -27,6 +27,9 @@ */ #pragma once +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" #include "options.h" #include "operation_profiler.h" diff 
--git a/tools/profiler/src/debug.h b/tools/profiler/src/debug.h index 3aaf3bb4..8aad2ee9 100644 --- a/tools/profiler/src/debug.h +++ b/tools/profiler/src/debug.h @@ -30,11 +30,11 @@ #include -#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } +//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } //#define report(x) {} // Enable/Disble Profiler debug prints -#define DEBUG_PROFILER +//#define DEBUG_PROFILER //RED 31m // profiler prints debug messages in red //YELLOW 33m // ir prints debug messages in yellow @@ -43,7 +43,7 @@ #define debugprof(...) #else #define debugprof(...) do { \ - printf("\033[31m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ + printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ printf(__VA_ARGS__); \ printf("\033[0m\n"); \ } while (0) diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 2e04a1e8..c97f0de4 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -34,12 +34,12 @@ #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/device/tensor_fill.h" - #include "cutlass/util/reference/host/tensor_fill.h" - #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/library/util.h" + #include "device_allocation.h" namespace cutlass { @@ -106,6 +106,18 @@ std::vector DeviceAllocation::get_packed_layout( case library::LayoutTypeID::kRowMajorInterleavedK16: stride = get_packed_layout_stride>(extent); break; + case library::LayoutTypeID::kColumnMajorInterleavedK32: + stride = get_packed_layout_stride>(extent); + break; + case library::LayoutTypeID::kRowMajorInterleavedK32: + stride = get_packed_layout_stride>(extent); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK64: + stride = get_packed_layout_stride>(extent); + break; + case 
library::LayoutTypeID::kRowMajorInterleavedK64: + stride = get_packed_layout_stride>(extent); + break; case library::LayoutTypeID::kTensorNCHW: stride = get_packed_layout_stride(extent); break; @@ -200,6 +212,18 @@ size_t DeviceAllocation::construct_layout( case library::LayoutTypeID::kRowMajorInterleavedK16: return construct_layout_>(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kColumnMajorInterleavedK32: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kRowMajorInterleavedK32: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kColumnMajorInterleavedK64: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kRowMajorInterleavedK64: + return construct_layout_>(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kTensorNCHW: return construct_layout_(bytes, layout_id, extent, stride); @@ -415,6 +439,14 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF64: + cutlass::reference::device::BlockFillRandom>( + reinterpret_cast *>(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kS8: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -508,6 +540,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF16: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF64: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -516,6 +556,14 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCF64: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + 
dist + ); + break; case library::NumericTypeID::kS8: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -607,13 +655,25 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); - + + case library::NumericTypeID::kCF16: + return reference::device::BlockCompareEqual>( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + case library::NumericTypeID::kF64: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); + case library::NumericTypeID::kCF64: + return reference::device::BlockCompareEqual>( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + case library::NumericTypeID::kS8: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index f695a9ed..780e0447 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -74,16 +74,34 @@ DeviceAllocation *DeviceContext::allocate_tensor( allocate_tensor(name, type, layout_id, extent, stride); if (options.initialization.enabled) { + Distribution data_distribution = options.initialization.data_distribution; - if (options.initialization.provider == Provider::kReferenceDevice) { + // check if data distribution is allowed to change + if(!options.initialization.fix_data_distribution) { + // change data distribution based on bit width + switch(type) { + case library::NumericTypeID::kB1: + data_distribution.set_uniform(0, 2, 0); + break; + case library::NumericTypeID::kS8: + data_distribution.set_uniform(-2, 2, 0); + break; + case library::NumericTypeID::kU8: + data_distribution.set_uniform(0, 4, 0); + break; + default: break; + } + } + + if (options.initialization.provider == library::Provider::kReferenceDevice) { allocation->initialize_random_device( options.initialization.seed, - 
options.initialization.data_distribution); + data_distribution); } - else if (options.initialization.provider == Provider::kReferenceHost) { + else if (options.initialization.provider == library::Provider::kReferenceHost) { allocation->initialize_random_host( options.initialization.seed, - options.initialization.data_distribution); + data_distribution); } } diff --git a/tools/profiler/src/enumerated_types.cpp b/tools/profiler/src/enumerated_types.cpp index 7ca41789..1acefb1f 100644 --- a/tools/profiler/src/enumerated_types.cpp +++ b/tools/profiler/src/enumerated_types.cpp @@ -123,53 +123,6 @@ AlgorithmMode from_string(std::string const &str) { ///////////////////////////////////////////////////////////////////////////////////////////////// - -static struct { - char const *text; - char const *pretty; - Provider enumerant; -} -Provider_enumerants[] = { - {"cutlass", "CUTLASS", Provider::kCUTLASS}, - {"host", "reference_host", Provider::kReferenceHost}, - {"device", "reference_device", Provider::kReferenceDevice}, - {"cublas", "cuBLAS", Provider::kCUBLAS}, -}; - -/// Converts a Provider enumerant to a string -char const *to_string(Provider provider, bool pretty) { - - for (auto const & possible : Provider_enumerants) { - if (provider == possible.enumerant) { - if (pretty) { - return possible.pretty; - } - else { - return possible.text; - } - } - } - - return pretty ? 
"Invalid" : "invalid"; -} - -/// Parses a Provider enumerant from a string -template <> -Provider from_string(std::string const &str) { - - for (auto const & possible : Provider_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { - return possible.enumerant; - } - } - - return Provider::kInvalid; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - - static struct { char const *text; char const *pretty; @@ -180,6 +133,7 @@ Disposition_enumerants[] = { {"failed", "Failed", Disposition::kFailed}, {"not_run", "Not run", Disposition::kNotRun}, {"not_verified", "Not verified", Disposition::kNotVerified}, + {"invalid_problem", "Invalid problem", Disposition::kInvalidProblem}, {"not_supported", "Not supported", Disposition::kNotSupported}, {"incorrect", "Incorrect", Disposition::kIncorrect} }; diff --git a/tools/profiler/src/enumerated_types.h b/tools/profiler/src/enumerated_types.h index f9b19423..051406d1 100644 --- a/tools/profiler/src/enumerated_types.h +++ b/tools/profiler/src/enumerated_types.h @@ -30,7 +30,9 @@ #include #include +#include #include +#include "cutlass/library/library.h" #define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } @@ -79,26 +81,6 @@ AlgorithmMode from_string(std::string const &str); ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Providers -enum class Provider { - kCUTLASS, - kReferenceHost, - kReferenceDevice, - kCUBLAS, - kInvalid -}; - -using ProviderVector = std::vector; - -/// Converts a Provider enumerant to a string -char const *to_string(Provider provider, bool pretty = false); - -/// Parses a Provider enumerant from a string -template <> -Provider from_string(std::string const &str); - -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Outcome of a performance test enum class Disposition { 
kPassed, @@ -106,12 +88,13 @@ enum class Disposition { kNotRun, kIncorrect, kNotVerified, + kInvalidProblem, kNotSupported, kInvalid }; /// Converts a Disposition enumerant to a string -char const *to_string(Disposition provider, bool pretty = false); +char const *to_string(Disposition disposition, bool pretty = false); /// Parses a Disposition enumerant from a string template <> @@ -159,6 +142,21 @@ char const *to_string(ArgumentTypeID type, bool pretty = false); template <> ArgumentTypeID from_string(std::string const &str); +///////////////////////////////////////////////////////////////////////////////////////////////// +// Profiler typedefs +using ProviderVector = std::vector; +using DispositionMap = std::map; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Print vector for the report +template +std::ostream& operator<< (std::ostream& out, const std::vector& v) { + for(int i = 0; i < v.size(); ++i) { + out << to_string(v[i], true) << (i+1 != v.size() ? 
"," : ""); + } + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 4c8cb86a..cb430968 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -140,7 +140,6 @@ Status GemmOperationProfiler::initialize_configuration( return Status::kErrorInvalidProblem; } - if (!arg_as_int(problem_.m, "m", problem_space, problem)) { // default value problem_.m = 1024; @@ -201,7 +200,7 @@ Status GemmOperationProfiler::initialize_configuration( return Status::kErrorInternal; } } - + problem_.lda = DeviceAllocation::get_packed_layout( operation_desc.A.layout, {int(problem_.m), int(problem_.k)}).front(); @@ -240,7 +239,7 @@ void GemmOperationProfiler::initialize_result_( library::GemmDescription const &operation_desc, ProblemSpace const &problem_space) { - result.provider = Provider::kCUTLASS; + result.provider = library::Provider::kCUTLASS; result.disposition = Disposition::kNotRun; result.status = Status::kSuccess; result.operation_name = operation_desc.name; @@ -277,9 +276,17 @@ void GemmOperationProfiler::initialize_result_( int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n * 2; result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); - result.runtime = 0; + // complex-valued support + switch (operation_desc.tile_description.math_instruction.math_operation) { + case library::MathOperationID::kMultiplyAddComplex: + result.flops *= 4; + break; + + default: break; + } + } /// Initializes workspace @@ -290,7 +297,7 @@ Status GemmOperationProfiler::initialize_workspace( library::Operation const *operation, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - + library::GemmDescription const &operation_desc = static_cast(operation->description()); @@ 
-348,7 +355,7 @@ Status GemmOperationProfiler::initialize_workspace( // Status status = Status::kSuccess; - if (options.profiling.provider_enabled(Provider::kCUTLASS)) { + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { if (options.execution_mode != ExecutionMode::kDryRun) { @@ -368,8 +375,12 @@ Status GemmOperationProfiler::initialize_workspace( // If CUTLASS is enabled, generate a result for it // results_.push_back(model_result_); - results_.back().provider = Provider::kCUTLASS; + results_.back().provider = library::Provider::kCUTLASS; + results_.back().op_kind = library::OperationKind::kGemm; results_.back().disposition = Disposition::kNotRun; + for(auto &verification_provider : options.verification.providers) { + results_.back().verification_map[verification_provider] = Disposition::kNotRun; + } } return status; @@ -386,7 +397,7 @@ bool GemmOperationProfiler::verify_cutlass( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - if (!options.profiling.provider_enabled(Provider::kCUTLASS)) { + if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { return true; } @@ -423,198 +434,62 @@ bool GemmOperationProfiler::verify_cutlass( return false; } + // CUTLASS op ran the but not yet verified against any verification provider results_.back().disposition = Disposition::kNotVerified; + // + // Run verification providers + // + if (options.verification.enabled) { #if CUTLASS_ENABLE_CUBLAS - if (options.verification.provider_enabled(Provider::kCUBLAS)) { + if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { // Guard against unsupported cases auto const & gemm_desc = static_cast(operation->description()); - if (cublas_satisfies(gemm_desc) != Status::kSuccess) { - return true; - } + if (cublas_satisfies(gemm_desc) == Status::kSuccess) { - return verify_with_cublas_( - options, - report, - device_context, - operation, - problem_space, - problem); + // call cublas verification if supported + 
verify_with_cublas_( + options, + report, + device_context, + operation, + problem_space, + problem); + } + + else { + // set verification map for cublas to not supported + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; + } } #endif // #if CUTLASS_ENABLE_CUBLAS + // Update disposition to worst case verification outcome among all + // verification providers which are supported + bool is_any_verification_run_passed = false; + for(auto &m : results_.back().verification_map) { + if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { + results_.back().disposition = m.second; + return true; + } + if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { + is_any_verification_run_passed = true; + } + } + + if(is_any_verification_run_passed) { + results_.back().disposition = Disposition::kPassed; + } } + // Return true means continue profiling return true; } - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if CUTLASS_ENABLE_CUBLAS - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Selects one or more cuBLAS algorithms. -static void select_cublas_algorithms( - std::vector &algorithms, - Options const &options, - library::GemmDescription const &op_desc) { - - library::OpcodeClassID const & opcode_class = - op_desc.tile_description.math_instruction.opcode_class; - - switch (options.library.algorithm_mode) { - case AlgorithmMode::kMatching: - { - algorithms.push_back(get_cublas_gemm_algo( - op_desc.tile_description.threadblock_shape.m(), - op_desc.tile_description.threadblock_shape.n(), - op_desc.tile_description.threadblock_shape.k(), - opcode_class)); - break; - } - - case AlgorithmMode::kBest: - { - // Choose first enumerated mode. If none are enumerated, choose based on opcode class - // and evaluate all of them. 
- - if (options.library.algorithms.empty()) { - // Enumerate all algorithms - if (opcode_class == library::OpcodeClassID::kSimt) { - - for (int algo = CUBLAS_GEMM_DEFAULT; - algo <= CUBLAS_GEMM_ALGO23; - ++algo) { - - algorithms.push_back(cublasGemmAlgo_t(algo)); - } - } - else { - - for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; - ++algo) { - - algorithms.push_back(cublasGemmAlgo_t(algo)); - } - } - } - else { - // Use the listed algorithms - algorithms.reserve(options.library.algorithms.size()); - - for (int algo : options.library.algorithms) { - algorithms.push_back(reinterpret_cast(algo)); - } - } - - break; - } - - case AlgorithmMode::kDefault: - { - - // Use the library's default algorithm - algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? - CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - break; - } - default: - { - break; - } - } -} - -/// Dispatcher to cublasGemmEx() -struct cublasGemmExDispatcher { - - // - // Data members - // - library::GemmConfiguration configuration; - library::GemmArguments arguments; - - cublasOperation_t trans_A; - cublasOperation_t trans_B; - cudaDataType_t data_type_A; - cudaDataType_t data_type_B; - cudaDataType_t data_type_C; - cudaDataType_t compute_type; - cublasGemmAlgo_t algo; - Status status; - - // - // Methods - // - - cublasGemmExDispatcher( - library::GemmDescription const &op_desc, - library::GemmConfiguration configuration_, - library::GemmArguments arguments_, - cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT - ): - configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) { - - trans_A = get_cublas_transpose_operation(op_desc.A.layout); - trans_B = get_cublas_transpose_operation(op_desc.B.layout); - - bool good = true; - good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); - good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); - good = (good && 
get_cublas_datatype(data_type_C, op_desc.C.element)); - - good = (good && get_cublas_datatype( - compute_type, - op_desc.tile_description.math_instruction.element_accumulator)); - - if (!good) { - status = Status::kErrorNotSupported; - } - } - - /// Executes GEMM using these arguments - cublasStatus_t operator()(cublasHandle_t handle) { - - return cublasGemmEx( - handle, - trans_A, - trans_B, - configuration.problem_size.m(), - configuration.problem_size.n(), - configuration.problem_size.k(), - arguments.alpha, - arguments.A, - data_type_A, - int(configuration.lda), - arguments.B, - data_type_B, - int(configuration.ldb), - arguments.beta, - arguments.D, - data_type_C, - int(configuration.ldc), - compute_type, - algo - ); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -#endif // CUTLASS_ENABLE_CUBLAS - /////////////////////////////////////////////////////////////////////////////////////////////////// /// Verifies CUTLASS against references @@ -632,14 +507,16 @@ bool GemmOperationProfiler::verify_with_cublas_( library::GemmDescription const &gemm_desc = static_cast(operation->description()); + // + // Construct cuBLAS operators + // + CublasCreate handle; cublasStatus_t status = handle.get_cublas_create_status(); if (status != CUBLAS_STATUS_SUCCESS) { - results_.back().status = get_cutlass_status(status); - results_.back().disposition = Disposition::kFailed; - + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; return true; } @@ -682,7 +559,8 @@ bool GemmOperationProfiler::verify_with_cublas_( ); if (gemm_op.status != Status::kSuccess) { - results_.back().disposition = Disposition::kNotVerified; + + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; return true; } @@ -692,8 +570,8 @@ bool GemmOperationProfiler::verify_with_cublas_( // Handle errors if (status != CUBLAS_STATUS_SUCCESS) { - 
results_.back().status = get_cutlass_status(status); - results_.back().disposition = Disposition::kNotVerified; + + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; return true; } @@ -701,7 +579,7 @@ bool GemmOperationProfiler::verify_with_cublas_( // Verify results // - results_.back().disposition = compare_tensors( + results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( options, *gemm_workspace_.Computed, *gemm_workspace_.Reference @@ -709,19 +587,18 @@ bool GemmOperationProfiler::verify_with_cublas_( // Save workspace if incorrect if (options.verification.save_workspace == SaveWorkspace::kIncorrect && - results_.back().disposition == Disposition::kIncorrect) { + results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { save_workspace( device_context, options, gemm_desc, - Provider::kCUTLASS, - Provider::kCUBLAS); + library::Provider::kCUTLASS, + library::Provider::kCUBLAS); } } catch (...) { - results_.back().disposition = Disposition::kFailed; - results_.back().status = Status::kErrorNotSupported; + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; } #endif @@ -741,7 +618,7 @@ bool GemmOperationProfiler::profile( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - if (options.profiling.provider_enabled(Provider::kCUTLASS)) { + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { // Initialize structure containing GEMM arguments gemm_workspace_.arguments.A = gemm_workspace_.A->data(); diff --git a/tools/profiler/src/gemm_operation_profiler.h b/tools/profiler/src/gemm_operation_profiler.h index 37401229..3bd0bb62 100644 --- a/tools/profiler/src/gemm_operation_profiler.h +++ b/tools/profiler/src/gemm_operation_profiler.h @@ -35,6 +35,7 @@ // CUTLASS Library includes #include "cutlass/library/library.h" +#include "cutlass/library/util.h" #include "cutlass/library/manifest.h" // Profiler includes 
diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index df227a58..6d21f87e 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -225,7 +225,7 @@ int OperationProfiler::profile_all( ProblemSpace problem_space(arguments_, options.cmdline); // 1. Construct performance report - PerformanceReport report(options, problem_space.argument_names()); + PerformanceReport report(options, problem_space.argument_names(), kind_); // 2. For each problem in problem space ProblemSpace::Iterator problem_it = problem_space.begin(); @@ -269,7 +269,7 @@ int OperationProfiler::profile_all( if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { continue; } - + // A. Initialize configuration Status status = this->initialize_configuration( options, @@ -278,7 +278,7 @@ int OperationProfiler::profile_all( operation, problem_space, problem); - + if (status == Status::kErrorInternal) { // Stop profiling if there was an internal error return false; @@ -341,7 +341,7 @@ int OperationProfiler::profile_all( device_context, options, operation->description(), - Provider::kCUTLASS); + library::Provider::kCUTLASS); } // @@ -434,8 +434,8 @@ void OperationProfiler::save_workspace( DeviceContext &device_context, Options const &options, library::OperationDescription const &desc, - Provider provider, - Provider verification_provider) { + library::Provider provider, + library::Provider verification_provider) { for (auto const & named_allocation : device_context) { @@ -443,10 +443,10 @@ void OperationProfiler::save_workspace( std::stringstream filename; - filename << desc.name << "_" << to_string(provider) << "_"; + filename << desc.name << "_" << library::to_string(provider) << "_"; - if (verification_provider != Provider::kInvalid) { - filename << "verified_by_" << to_string(verification_provider) << "_"; + if (verification_provider != library::Provider::kInvalid) { + filename << 
"verified_by_" << library::to_string(verification_provider) << "_"; } filename << named_allocation.first + ".mat"; @@ -454,6 +454,7 @@ void OperationProfiler::save_workspace( std::ofstream out(filename.str()); allocation->write_tensor_csv(out); + out << "\n"; if (options.report.verbose) { std::cout << "wrote '" << filename.str() << "'" << std::endl; diff --git a/tools/profiler/src/operation_profiler.h b/tools/profiler/src/operation_profiler.h index 3019f3bd..ce06b1c9 100644 --- a/tools/profiler/src/operation_profiler.h +++ b/tools/profiler/src/operation_profiler.h @@ -35,6 +35,7 @@ // CUTLASS Library includes #include "cutlass/library/library.h" +#include "cutlass/library/util.h" #include "cutlass/library/manifest.h" // Profiler includes @@ -43,6 +44,7 @@ #include "performance_result.h" #include "performance_report.h" #include "problem_space.h" +#include "debug.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -192,8 +194,8 @@ public: DeviceContext &device_context, Options const &options, library::OperationDescription const &desc, - Provider provider, - Provider verification_provider = Provider::kInvalid); + library::Provider provider, + library::Provider verification_provider = library::Provider::kInvalid); protected: diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 367dc7a2..946e536c 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -31,6 +31,8 @@ #include "cutlass/cutlass.h" #include "cutlass/version.h" +#include "cutlass/library/util.h" + #include "options.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -161,24 +163,30 @@ Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) { if (cmdline.check_cmd_line_flag("initialization-provider")) { std::string str; cmdline.get_cmd_line_argument("initialization-provider", str); - provider = from_string(str); - if (provider == 
Provider::kInvalid) { + provider = library::from_string(str); + if (provider == library::Provider::kInvalid) { enabled = false; } - else if (provider != Provider::kReferenceHost && provider != Provider::kReferenceDevice) { + else if (provider != library::Provider::kReferenceHost && provider != library::Provider::kReferenceDevice) { throw std::runtime_error("Unsupported intialization provider specified."); } } else { - provider = Provider::kReferenceDevice; + provider = library::Provider::kReferenceDevice; } cmdline.get_cmd_line_argument("seed", seed, 2019); if (cmdline.check_cmd_line_flag("dist")) { + // user has set the data distribution (fix data distribution once set) + fix_data_distribution = true; + // set user provided data distribution get_distribution(cmdline, "dist", data_distribution); } else { + // profiler choosen data distribution (allowed to change based on numeric types) + fix_data_distribution = false; + // set uniform data distribution with range [-4, 4] data_distribution.set_uniform(-4, 4, 0); } @@ -372,12 +380,12 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) { providers.clear(); for (auto const &token : tokens) { - providers.push_back(from_string(token)); + providers.push_back(library::from_string(token)); } } else { - providers.push_back(Provider::kCUTLASS); - providers.push_back(Provider::kCUBLAS); + providers.push_back(library::Provider::kCUTLASS); + providers.push_back(library::Provider::kCUBLAS); } } @@ -412,18 +420,18 @@ void Options::Profiling::print_options(std::ostream &out, int indent) const { int j = 0; for (auto const & provider : providers) { - out << (j++ ? ", " : "") << to_string(provider); + out << (j++ ? 
", " : "") << library::to_string(provider); } out << "]\n"; } /// Returns true if a provider is enabled -bool Options::Profiling::provider_enabled(Provider provider) const { +bool Options::Profiling::provider_enabled(library::Provider provider) const { return std::find(providers.begin(), providers.end(), provider) != providers.end(); } /// Returns the index of a provider if its enabled -size_t Options::Profiling::index(Provider provider) const { +size_t Options::Profiling::index(library::Provider provider) const { size_t idx = 0; for (auto const & x : providers) { if (x == provider) { @@ -461,14 +469,14 @@ Options::Verification::Verification(cutlass::CommandLine const &cmdline) { providers.clear(); for (auto const &token : tokens) { - Provider provider = from_string(token); - if (provider != Provider::kInvalid) { + library::Provider provider = library::from_string(token); + if (provider != library::Provider::kInvalid) { providers.push_back(provider); } } } else { - providers.push_back(Provider::kCUBLAS); + providers.push_back(library::Provider::kCUBLAS); } } @@ -504,18 +512,18 @@ void Options::Verification::print_options(std::ostream &out, int indent) const { int j = 0; for (auto const & provider : providers) { - out << (j++ ? ", " : "") << to_string(provider); + out << (j++ ? 
", " : "") << library::to_string(provider); } out << "]\n"; } /// Returns true if a provider is enabled -bool Options::Verification::provider_enabled(Provider provider) const { +bool Options::Verification::provider_enabled(library::Provider provider) const { return std::find(providers.begin(), providers.end(), provider) != providers.end(); } /// Returns the index of a provider if its enabled -size_t Options::Verification::index(Provider provider) const { +size_t Options::Verification::index(library::Provider provider) const { size_t idx = 0; for (auto const & x : providers) { if (x == provider) { @@ -658,7 +666,7 @@ Options::Options(cutlass::CommandLine const &cmdline): // Prevent launches on the device for anything other than CUTLASS operation if (execution_mode == ExecutionMode::kTrace) { - initialization.provider = Provider::kReferenceHost; + initialization.provider = library::Provider::kReferenceHost; verification.enabled = false; profiling.enabled = false; } diff --git a/tools/profiler/src/options.h b/tools/profiler/src/options.h index 26cb93c7..4f723fa5 100644 --- a/tools/profiler/src/options.h +++ b/tools/profiler/src/options.h @@ -105,11 +105,15 @@ public: /// allocating tensors. bool enabled; + /// If true, data distribution is set by the user and is not allowed to change + /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) + bool fix_data_distribution; + /// Data distribution for input tensors Distribution data_distribution; /// Source of random tensor elements - Provider provider; + library::Provider provider; /// Random number generator seed. 
int seed; @@ -162,10 +166,10 @@ public: void print_options(std::ostream &out, int indent = 0) const; /// Returns true if a provider is enabled - bool provider_enabled(Provider provider) const; + bool provider_enabled(library::Provider provider) const; /// Returns the index of a provider if its enabled - size_t index(Provider provider) const; + size_t index(library::Provider provider) const; }; /// Options related to profiling @@ -196,10 +200,10 @@ public: void print_options(std::ostream &out, int indent = 0) const; /// Returns true if a provider is enabled - bool provider_enabled(Provider provider) const; + bool provider_enabled(library::Provider provider) const; /// Returns the index of a provider if its enabled - size_t index(Provider provider) const; + size_t index(library::Provider provider) const; }; /// Options related to reporting diff --git a/tools/profiler/src/performance_report.cpp b/tools/profiler/src/performance_report.cpp index fd05155f..52a82099 100644 --- a/tools/profiler/src/performance_report.cpp +++ b/tools/profiler/src/performance_report.cpp @@ -29,9 +29,15 @@ #include #include #include +#include +#include + +#include "cutlass/library/util.h" + +#include "cutlass/library/util.h" #include "performance_report.h" - +#include "debug.h" namespace cutlass { namespace profiler { @@ -57,12 +63,17 @@ namespace profiler { PerformanceReport::PerformanceReport( Options const &options, - std::vector const &argument_names + std::vector const &argument_names, + library::OperationKind const &op_kind ): - options_(options), argument_names_(argument_names), problem_index_(0), good_(true) { + options_(options), argument_names_(argument_names), problem_index_(0), good_(true), op_kind_(op_kind) { + + std::string file_name = options_.report.output_path.substr(0, options_.report.output_path.rfind(".")); + std::string file_extension = options_.report.output_path.substr(options_.report.output_path.rfind(".") + 1); + op_file_name_ = file_name + "." 
+ to_string(op_kind_) + "." + file_extension; // - // Open output file + // Open output file for operation of PerformanceReport::op_kind // if (!options_.report.output_path.empty()) { @@ -70,17 +81,17 @@ PerformanceReport::PerformanceReport( if (options_.report.append) { - std::ifstream test_output_file(options_.report.output_path.c_str()); + std::ifstream test_output_file(op_file_name_); if (test_output_file.is_open()) { print_header = false; test_output_file.close(); } - output_file_.open(options_.report.output_path.c_str(), std::ios::app); + output_file_.open(op_file_name_, std::ios::app); } else { - output_file_.open(options_.report.output_path.c_str()); + output_file_.open(op_file_name_); } if (!output_file_.good()) { @@ -148,7 +159,7 @@ void PerformanceReport::close() { } } else if (output_file_.is_open() && options_.report.verbose) { - std::cout << "\n\nWrote results to '" << options_.report.output_path << "'" << std::endl; + std::cout << "\n\nWrote results to '" << op_file_name_ << "'" << std::endl; } } @@ -184,19 +195,30 @@ std::ostream & PerformanceReport::print_result_pretty_( out << "\n" - << " Provider: " << SHELL_COLOR_BRIGHT() << to_string(result.provider, true) << SHELL_COLOR_END() << "\n" - << " Operation: " << result.operation_name << "\n\n" - << " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n" - << " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n"; + << " Provider: " << SHELL_COLOR_BRIGHT() << library::to_string(result.provider, true) << SHELL_COLOR_END() << "\n" + << " Operation: " << result.operation_name << "\n\n" + << " Status: " << SHELL_COLOR_BRIGHT() << library::to_string(result.status, true) << SHELL_COLOR_END() << "\n" + << " Verification: " << SHELL_COLOR_BRIGHT() << (options_.verification.enabled ? 
"ON":"OFF") << SHELL_COLOR_END() << "\n" + << " Disposition: " << disposition_status_color(result.disposition) << to_string(result.disposition, true) << SHELL_COLOR_END() << "\n\n"; + + // Display individual verification results for each verification-provider + if (options_.verification.enabled) { + + static int const indent_spaces = 22; + + for(auto & m : result.verification_map) { + out << std::right << std::setw(indent_spaces) << library::to_string(m.first, true) << ": " << to_string(m.second, true) << "\n"; + } + } out - << "\n Arguments: "; + << "\n Arguments: "; int column_idx = 0; for (auto const &arg : result.arguments) { if (!arg.second.empty()) { out << " --" << arg.first << "=" << arg.second; - column_idx += 4 + arg.first.size() + arg.second.size(); + column_idx += int(4 + arg.first.size() + arg.second.size()); if (column_idx > 90) { out << " \\\n "; column_idx = 0; @@ -206,15 +228,15 @@ std::ostream & PerformanceReport::print_result_pretty_( out << "\n\n"; out - << " Bytes: " << result.bytes << " bytes\n" - << " FLOPs: " << result.flops << " flops\n\n"; + << " Bytes: " << result.bytes << " bytes\n" + << " FLOPs: " << result.flops << " flops\n\n"; if (result.good()) { out - << " Runtime: " << result.runtime << " ms\n" - << " Memory: " << result.gbytes_per_sec() << " GiB/s\n" - << "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n"; + << " Runtime: " << result.runtime << " ms\n" + << " Memory: " << result.gbytes_per_sec() << " GiB/s\n" + << "\n Math: " << result.gflops_per_sec() << " GFLOP/s\n"; } diff --git a/tools/profiler/src/performance_report.h b/tools/profiler/src/performance_report.h index 1022efac..573a049e 100644 --- a/tools/profiler/src/performance_report.h +++ b/tools/profiler/src/performance_report.h @@ -31,10 +31,14 @@ #include #include +// CUTLASS Profiler includes #include "options.h" #include "enumerated_types.h" #include "performance_result.h" +// CUTLASS Library includes +#include "cutlass/library/library.h" + namespace cutlass { 
namespace profiler { @@ -46,6 +50,12 @@ private: /// Reference to options Options const &options_; + /// Operation kind + library::OperationKind op_kind_; + + /// Operation file name containing performance report of op_kind + std::string op_file_name_; + /// Output file containing results std::ofstream output_file_; @@ -63,7 +73,7 @@ private: public: - PerformanceReport(Options const &options, std::vector const &argument_names); + PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); bool good() const { return good_; } diff --git a/tools/profiler/src/performance_result.h b/tools/profiler/src/performance_result.h index b710099d..23eb60f2 100644 --- a/tools/profiler/src/performance_result.h +++ b/tools/profiler/src/performance_result.h @@ -32,8 +32,12 @@ #include "cutlass/cutlass.h" +// CUTLASS Profiler includes #include "enumerated_types.h" +// CUTLASS Library includes +#include "cutlass/library/library.h" + namespace cutlass { namespace profiler { @@ -45,15 +49,22 @@ struct PerformanceResult { /// Index of problem size_t problem_index; - /// Provider - Provider provider; + /// library::Provider + library::Provider provider; - /// Outcome of test - Disposition disposition; + /// Operation kind + library::OperationKind op_kind; - /// CUTLASS status result from kernels + /// CUTLASS status result from kernels (success or failure) + // Status does information on verification Status status; + /// Outcome of verification (worst case verification result) + Disposition disposition; + + /// Outcome of verification (all verification results) + DispositionMap verification_map; + /// Operation object std::string operation_name; @@ -76,7 +87,8 @@ struct PerformanceResult { /// Ctor PerformanceResult(): problem_index(0), - provider(Provider::kInvalid), + op_kind(library::OperationKind::kInvalid), + provider(library::Provider::kInvalid), disposition(Disposition::kNotRun), status(Status::kInvalid), bytes(0), diff --git 
a/tools/profiler/src/problem_space.cpp b/tools/profiler/src/problem_space.cpp index 33656bef..e95b9e1b 100644 --- a/tools/profiler/src/problem_space.cpp +++ b/tools/profiler/src/problem_space.cpp @@ -27,10 +27,11 @@ */ #include -#include #include #include +#include "cutlass/library/util.h" + #include "problem_space.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -849,17 +850,16 @@ bool arg_as_OpcodeClassID( } ///////////////////////////////////////////////////////////////////////////////////////////////// - /// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. bool arg_as_scalar( std::vector &bytes, library::NumericTypeID numeric_type, KernelArgument::Value const *value_ptr) { - + if (value_ptr->not_null) { if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) { int64_t int_value = static_cast(value_ptr)->value; - + // TODO - convert int64_t => destination type } else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { diff --git a/tools/util/CMakeLists.txt b/tools/util/CMakeLists.txt index 6cda38ac..51be4b54 100644 --- a/tools/util/CMakeLists.txt +++ b/tools/util/CMakeLists.txt @@ -31,6 +31,12 @@ target_include_directories( $ ) +target_link_libraries( + cutlass_tools_util_includes + INTERFACE + $<$:cublas> + ) + install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ @@ -40,3 +46,4 @@ install( TARGETS cutlass_tools_util_includes EXPORT NvidiaCutlass ) + diff --git a/tools/util/include/cutlass/util/command_line.h b/tools/util/include/cutlass/util/command_line.h index 008d0e73..31fa7f34 100644 --- a/tools/util/include/cutlass/util/command_line.h +++ b/tools/util/include/cutlass/util/command_line.h @@ -119,6 +119,16 @@ struct CommandLine { val = !(value == "0" || value == "false"); } } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void 
get_cmd_line_argument(const char* arg_name, + value_t& val) const { + + get_cmd_line_argument(arg_name, val, val); + } /** * Obtains the value specified for a given commandline parameter --= @@ -126,7 +136,7 @@ struct CommandLine { template void get_cmd_line_argument(const char* arg_name, value_t& val, - value_t const& _default = value_t()) const { + value_t const& _default) const { using namespace std; val = _default; diff --git a/tools/util/include/cutlass/util/device_memory.h b/tools/util/include/cutlass/util/device_memory.h index e8f13d3b..52229425 100644 --- a/tools/util/include/cutlass/util/device_memory.h +++ b/tools/util/include/cutlass/util/device_memory.h @@ -40,10 +40,14 @@ namespace device_memory { /// Allocate a buffer of \p count elements of type \p T on the current CUDA device template T* allocate(size_t count = 1) { + T* ptr = 0; - size_t bytes = sizeof(T) * count; + size_t bytes = 0; + + bytes = count * sizeof(T); cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + if (cuda_error != cudaSuccess) { throw cuda_exception("Failed to allocate memory", cuda_error); } @@ -111,13 +115,16 @@ void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { copy_to_device(device_begin, &*begin, elements); } -/****************************************************************************** - * "Smart" device memory allocation - ******************************************************************************/ +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// -/// Device allocation abstraction that tracks size and capacity template -struct allocation { +class DeviceAllocation { +public: + /// Delete functor for CUDA device memory struct deleter { void operator()(T* ptr) { @@ -130,6 +137,7 @@ struct allocation { } }; +public: // // Data members // @@ -140,23 
+148,55 @@ struct allocation { /// Smart pointer platform::unique_ptr smart_ptr; +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + // // Methods // /// Constructor: allocates no memory - allocation() : capacity(0) {} + DeviceAllocation() : capacity(0) {} /// Constructor: allocates \p capacity elements on the current CUDA device - allocation(size_t _capacity) : smart_ptr(allocate(_capacity)), capacity(_capacity) {} + DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} /// Copy constructor - allocation(allocation const &p): smart_ptr(allocate(p.capacity)), capacity(p.capacity) { - copy_device_to_device(smart_ptr.get(), p.get(), capacity); + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); } /// Destructor - ~allocation() { reset(); } + ~DeviceAllocation() { reset(); } /// Returns a pointer to the managed object T* get() const { return smart_ptr.get(); } @@ -173,12 +213,41 @@ struct allocation { smart_ptr.reset(); } + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + 
reset(device_memory::allocate(_capacity), _capacity); + } + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity void reset(T* _ptr, size_t _capacity) { smart_ptr.reset(_ptr); capacity = _capacity; } + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. + void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + /// Returns a pointer to the object owned by *this T* operator->() const { return smart_ptr.get(); } @@ -189,15 +258,69 @@ struct allocation { const deleter& get_deleter() const { return smart_ptr.get_deleter(); } /// Copies a device-side memory allocation - allocation & operator=(allocation const &p) { + DeviceAllocation & operator=(DeviceAllocation const &p) { if (capacity != p.capacity) { - smart_ptr.reset(allocate(p.capacity)); + smart_ptr.reset(device_memory::allocate(p.capacity)); capacity = p.capacity; } copy_device_to_device(smart_ptr.get(), p.get(), capacity); return *this; } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. 
+ void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + } // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index 0a08b6e0..b43186a0 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -99,7 +99,7 @@ public: using ConstReference = typename ConstTensorRef::Reference; /// Used to handle packing of subbyte elements - static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? sizeof(Element) * 8 / sizeof_bits::value : 1); + static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? 
(8 / sizeof_bits::value) : 1); private: @@ -232,7 +232,7 @@ public: /// Returns the logical capacity based on extent and layout. May differ from size(). LongIndex capacity() const { - return layout_.capacity(extent_) * kElementsPerStoredItem; + return layout_.capacity(extent_); } /// Gets pointer to host data diff --git a/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/tools/util/include/cutlass/util/host_tensor_planar_complex.h new file mode 100644 index 00000000..a5e990cf --- /dev/null +++ b/tools/util/include/cutlass/util/host_tensor_planar_complex.h @@ -0,0 +1,423 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. 
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/matrix_traits.h" + +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensorPlanarComplex { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRefPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorViewPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + + private: + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// 
Device-side memory + device_memory::allocation device_; + + public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensorPlanarComplex() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensorPlanarComplex( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensorPlanarComplex( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensorPlanarComplex() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated + + device_.reset(); + host_.clear(); + + host_.resize(count * 2); + + // Allocate memory + Element* device_memory = nullptr; + if (device_backed_) { + device_memory = device_memory::allocate(count * 2); + } + device_.reset(device_memory, device_backed_ ? count * 2 : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. 
+ void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + + if (static_cast(new_size * 2) > host_.size()) { + reserve(new_size); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the number of elements stored in the host tensor + size_t size() const { + return host_.size() / 2; + } + + /// Returns the logical capacity based on extent and layout. May differ from size(). 
+ LongIndex capacity() const { + return layout_.capacity(extent_); + } + + /// Stride between real and imaginary parts + LongIndex imaginary_stride() const { + return host_.size() / 2; + } + + /// Gets pointer to host data + Element * host_data() { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element * host_data_imag() { return host_.data() + imaginary_stride(); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } + + /// Gets pointer to host data with a pointer offset + Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to device data + Element * device_data() { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } + + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { + return 
TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_real() { + return cutlass::TensorRef(host_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_imag() { + return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { + return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_real() { + return cutlass::TensorRef(device_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_imag() { + return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_real() { + return 
cutlass::TensorView(host_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_imag() { + return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_real() { + return cutlass::TensorView(device_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_imag() { + return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? 
false : true; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_data(), device_data(), imaginary_stride() * 2); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_data(), host_data(), imaginary_stride() * 2); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h new file mode 100644 index 00000000..10ce474e --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. 
All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/matrix_traits.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = 
convert_op(result.imag());; + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + 
tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 0c8e1ac4..34ba2475 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -43,12 +43,12 @@ #endif // CUDA includes -#include #include // Cutlass includes #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/complex.h" #include "cutlass/tensor_view.h" #include "cutlass/util/reference/device/tensor_foreach.h" @@ -169,6 +169,95 @@ struct RandomGaussianFunc { } }; + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, 
float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale))); + rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale))); + + result = { + Real(rnd_r / FloatType(IntType(1) << params.int_scale)), + Real(rnd_i / FloatType(IntType(1) << params.int_scale)) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + /// Computes a random Gaussian distribution template < typename Element, ///< Element type @@ -269,12 +358,12 @@ template ///< Element type void BlockFillRandomGaussian( Element *ptr, size_t capacity, - uint64_t seed, ///< seed for RNG - Element mean = Element(0), ///< Gaussian distribution's mean - Element stddev = Element(1), ///< Gaussian 
distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. using RandomFunc = detail::RandomGaussianFunc; @@ -383,6 +472,111 @@ struct RandomUniformFunc { } }; +/// Computes a random Gaussian distribution +template ///< Layout function +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale))); + rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale))); + + result = { + Real(rnd_r / FloatType(IntType(1) << params.int_scale)), + Real(rnd_i / FloatType(IntType(1) << params.int_scale)) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + /// Computes a random Gaussian distribution template < typename Element, ///< Element type @@ -489,8 +683,8 @@ void BlockFillRandomUniform( Element *ptr, size_t capacity, uint64_t seed, ///< seed for RNG - Element max = Element(1), ///< upper bound of distribution - Element min = Element(0), ///< lower bound for distribution + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution int bits = -1) { ///< If non-negative, specifies number of fractional bits that /// are not truncated to 
zero. Permits reducing precision of /// data. @@ -976,13 +1170,15 @@ void BlockFillRandom( uint64_t seed, Distribution dist) { + using Real = typename RealType::Type; + if (dist.kind == Distribution::Gaussian) { BlockFillRandomGaussian( ptr, capacity, seed, - static_cast(dist.gaussian.mean), - static_cast(dist.gaussian.stddev), + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), dist.int_scale); } else if (dist.kind == Distribution::Uniform) { @@ -990,8 +1186,8 @@ void BlockFillRandom( ptr, capacity, seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), dist.int_scale); } } diff --git a/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_complex.h index 964a69c4..0f067691 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm_complex.h +++ b/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -72,6 +72,7 @@ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, + TensorRef tensor_d, ComputeType initial_accum) { static_assert( @@ -138,7 +139,7 @@ void GemmComplex( if (row < M && col < N) { - tensor_c.at(coord) = convert_op( + tensor_d.at(coord) = convert_op( alpha * ScalarType(accum[i][j]) + beta * ScalarType(tensor_c.at(coord))); } @@ -171,9 +172,10 @@ void GemmComplex( TensorRef tensor_b, ComplexTransform transform_b, ScalarType beta, - TensorRef tensor_c) { + TensorRef tensor_c, + TensorRef tensor_d) { - GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0)); + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h 
b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h new file mode 100644 index 00000000..4d02747d --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/matrix_traits.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag());; + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a 
general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/tools/util/include/cutlass/util/reference/host/tensor_compare.h index 3c7d95ff..bf05a099 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -33,6 +33,9 @@ // Cutlass includes #include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" + #include "cutlass/util/distribution.h" //#include "cutlass/util/type_traits.h" #include "tensor_foreach.h" @@ -112,6 +115,46 @@ bool TensorEquals( return bool(func); } +/// Returns true if two tensor views are equal. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -137,6 +180,17 @@ bool TensorNotEquals( return !bool(func); } +/// Returns true if two tensor views are equal. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + return !TensorEquals(lhs, rhs); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 37096f73..b298e4c2 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -38,6 +38,8 @@ #include "cutlass/complex.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" #include "cutlass/util/distribution.h" #include "tensor_foreach.h" @@ -101,6 +103,18 @@ void TensorFill( ); } +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorViewPlanarComplex dst, ///< destination tensor + cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with + + TensorFill(dst.view_real(), val.real()); + TensorFill(dst.view_imag(), val.imag()); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -268,6 +282,23 @@ void TensorFillRandomGaussian( ); } +/// Fills a tensor with random values with a Gaussian distribution. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Fills a tensor with random values with a Gaussian distribution. @@ -461,6 +492,23 @@ void TensorFillRandomUniform( ); } +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Fills a tensor with random values with a uniform random distribution. 
@@ -774,6 +822,27 @@ void BlockFillSequential( } } +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/tensor_view_io.h b/tools/util/include/cutlass/util/tensor_view_io.h index c764c61c..590462f7 100644 --- a/tools/util/include/cutlass/util/tensor_view_io.h +++ b/tools/util/include/cutlass/util/tensor_view_io.h @@ -26,6 +26,8 @@ #include "cutlass/core_io.h" #include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" namespace cutlass { @@ -87,13 +89,13 @@ inline std::ostream & TensorView_WriteRank( coord[rank] = idx; if (rank + 2 == Layout::kRank) { - // Write least significant ranks asa matrix with rows delimited by ";\n" - out << (idx ? ";\n" : ""); + // Write least significant ranks asa matrix with rows delimited by "\n" + out << (idx ? ",\n" : ""); TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); } else { // Higher ranks are separated by newlines - out << (idx ? "\n" : ""); + out << (idx ? 
",\n\n" : ""); TensorView_WriteRank(out, view, coord, rank + 1, width); } } @@ -101,6 +103,76 @@ inline std::ostream & TensorView_WriteRank( return out; } +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by ";\n" + out << (idx ? ";\n" : ""); + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + out << (idx ? 
"\n" : ""); + TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + } // namespace detail /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -143,4 +215,42 @@ inline std::ostream& operator<<( /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass