diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d50be45..2fd274c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,33 @@ # CUTLASS 2.x +## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22) + * Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) + * Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h) + * Fused operators with GEMM and Convolution + * [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) + * [Fused partial reduction in epilogue](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) + * 64b tensor strides and leading dimensions support for GEMMs + * Affine rank=2 matrix layouts + * Row stride and column stride for matrices using [cutlass::layout::AffineRank2](/include/cutlass/layout/matrix.h) + * Support [FP64 tensor core](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM. 
+ * [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation + * [New strided Dgrad](test/unit/gemm/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation + * Accelerates over previous implementation by cutting down redundant math by 4x + * Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface + * Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores) + * Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h) + * SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu) + * [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu) + * Many improvements to the epilogue. + * Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations + * Performance improvement for FP16 tensor core kernels + * Bug fixes + * Updated minimum CUDA Toolkit requirement to 10.2 + * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended + * Corrections and bug fixes reported by the CUTLASS community + * Thank you for filing these issues! 
+ ## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26) * Tensor reductions * _m_-to-_n_ reductions of tensors with affine layout diff --git a/CMakeLists.txt b/CMakeLists.txt index 4abf54a9..895e45f8 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,9 +32,15 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -project(CUTLASS VERSION 2.5.0 LANGUAGES CXX) +project(CUTLASS VERSION 2.6.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) +if (CUDA_VERSION VERSION_LESS 10.2) + message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.") +elseif (CUDA_VERSION VERSION_LESS 11.0) + message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.") +endif() + find_package(Doxygen QUIET) # @@ -105,7 +111,7 @@ endif() if (NOT CUDA_VERSION VERSION_LESS 11.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80) endif() -if (NOT CUDA_VERSION VERSION_LESS 11.1) +if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") @@ -275,7 +281,14 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") endif() + # There are numerous Clang versions that can work with each CUDA toolkit and the + # the checks are not very useful so we are turning them off and using testing to + # ensure the various combinations work properly. 
+ list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__NV_NO_HOST_COMPILER_CHECK=1) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unknown-cuda-version) + list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument) @@ -294,18 +307,28 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") link_libraries(nvidia::cudart) endif() +if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) + # CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this + # property for CMake 3.18+, so we request the NEW behavior for correct compatibility. + # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 + cmake_policy(SET CMP0104 NEW) +endif() + function(cutlass_apply_cuda_gencode_flags TARGET) set(NVCC_FLAGS) set(CLANG_FLAGS) + set(__CMAKE_CUDA_ARCHS) foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED}) list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH}) set(CODES) if(CUTLASS_NVCC_EMBED_CUBIN) list(APPEND CODES sm_${ARCH}) + list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real) endif() if(CUTLASS_NVCC_EMBED_PTX) list(APPEND CODES compute_${ARCH}) + list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual) endif() list(JOIN CODES "," CODES_STR) list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) @@ -317,6 +340,8 @@ function(cutlass_apply_cuda_gencode_flags TARGET) PRIVATE $<$:${CLANG_FLAGS}> ) + elseif(CMAKE_VERSION GREATER_EQUAL 3.18) + set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) else() target_compile_options( ${TARGET} @@ -542,10 +567,14 @@ function(cutlass_add_executable_tests NAME TARGET) # set(options DISABLE_EXECUTABLE_INSTALL_RULE) - set(oneValueArgs) + set(oneValueArgs DISABLE_TESTS) set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" 
${ARGN}) - + + if (NOT DEFINED __DISABLE_TESTS) + set(__DISABLE_TESTS OFF) + endif() + if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS) # file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}) @@ -610,6 +639,8 @@ function(cutlass_add_executable_tests NAME TARGET) COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${CMD_OPTIONS} ) + set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS}) + if (CUTLASS_INSTALL_TESTS) # To run the tests from an install package with tests enabled, we need to generate test files diff --git a/README.md b/README.md index bf2d5c92..b488ca8e 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.5 +# CUTLASS 2.6 -_CUTLASS 2.5 - February 2021_ +_CUTLASS 2.6 - July 2021_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. @@ -34,12 +34,24 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. See the [functionality listing](/media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. 
+# What's New in CUTLASS 2.6 +CUTLASS 2.6 is a minor update to CUTLASS adding: +- Fused [broadcast](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reductions](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution +- [Quaternion-valued GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_conv/quaternion_conv.cu) in single-precision +- [New strided Dgrad](test/unit/gemm/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation offers up to 4x performance improvements over previous strided Dgrad +- 64-bit strides for large tensor allocations +- [General affine layouts](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) fp64 tensor core and simt GEMM +- Enhanced functionality, boosted performance, and bug fixes in the epilogue. +- Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) +- Adopt new L2 prefetch feature in [ptx instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-isa-version-7-4). +- Numerous updates from the community (thanks!) 
+- See the [CHANGELOG](CHANGELOG.md) for more details + # What's New in CUTLASS 2.5 CUTLASS 2.5 is a minor update to CUTLASS adding: - [Tensor reductions](/test/unit/reduction/device/tensor_reduce_contiguous.cu) - [Optimizations for 3-D convolution](include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) - [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md) -- See the [CHANGELOG](CHANGELOG.md) for more details # What's New in CUTLASS 2.4 CUTLASS 2.4 is a significant update to CUTLASS adding: @@ -52,7 +64,7 @@ CUTLASS 2.4 is a significant update to CUTLASS adding: CUTLASS 2.3 is a minor update to CUTLASS adding: - GEMMs targeting structured [Sparse Tensor Cores](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) in NVIDIA Ampere Architecture GPUs - Fast SGEMM kernels targeting GeForce RTX 30-series CUDA Cores -- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit) +- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later # What's New in CUTLASS 2.2 @@ -62,7 +74,7 @@ CUTLASS 2.2 is a significant update to CUTLASS adding: - Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types - Deep software pipelines using asynchronous copy - Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) -- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) +- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later # What's New in CUTLASS 2.1 @@ -95,8 +107,8 @@ using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's # Compatibility CUTLASS requires a C++11 host compiler and -performs best when built with the [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit). 
-It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, CUDA 10.2, and CUDA 11.0. +performs best when built with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit). +It is also compatible with CUDA 10.2, CUDA 11.0, CUDA 11.1, CUDA 11.2, and CUDA 11.3. We have tested the following environments. @@ -106,12 +118,16 @@ We have tested the following environments. | | Microsoft Visual Studio 2017| | Ubuntu 16.04 | GCC 5.4.0 | | Ubuntu 18.04 | GCC 7.5.0 | +| Ubuntu 20.04 | GCC 10.2.0 | Additionally, CUTLASS may be built with clang. See [these instructions](media/docs/quickstart.md#clang) for more details. CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on -any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. +any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. + +For all GPUs, we recommend compiling with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) +for best performance. |**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**| |---|---|---|---| @@ -511,6 +527,7 @@ CUTLASS is released by NVIDIA Corporation as Open Source software under the The official list of CUTLASS developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md). + # Copyright Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
diff --git a/cmake/CTestTestfile.config.cmake b/cmake/CTestTestfile.config.cmake index 65fda51a..0705b19c 100644 --- a/cmake/CTestTestfile.config.cmake +++ b/cmake/CTestTestfile.config.cmake @@ -17,3 +17,5 @@ add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "") set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@") endif() + +set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) diff --git a/examples/01_cutlass_utilities/cutlass_utilities.cu b/examples/01_cutlass_utilities/cutlass_utilities.cu index 8d6bf6a6..e4a11d74 100644 --- a/examples/01_cutlass_utilities/cutlass_utilities.cu +++ b/examples/01_cutlass_utilities/cutlass_utilities.cu @@ -119,12 +119,12 @@ cudaError_t cutlass_hgemm_nn( int K, cutlass::half_t alpha, cutlass::half_t const *A, - int lda, + cutlass::layout::ColumnMajor::Stride::Index lda, cutlass::half_t const *B, - int ldb, + cutlass::layout::ColumnMajor::Stride::Index ldb, cutlass::half_t beta, cutlass::half_t *C, - int ldc) { + cutlass::layout::ColumnMajor::Stride::Index ldc) { // Define the GEMM operation using Gemm = cutlass::gemm::device::Gemm< diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu index 23d5a95e..83e6fab8 100644 --- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu +++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu @@ -67,7 +67,7 @@ beta * C). Now that we setup the properties of data, we have to setup properties of computation. Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, -64x64x4, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally +64x64x32, 8x8x4 (MxNxK) respectively. 
When passed to instantiate CUTLASS GEMM kernel, it internally deduce the amount of threads needed per thread-block, amount of shared memory, storing data in bank-conflict free manner, and ton of other variables required to compose, intialize and launch a high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from diff --git a/examples/10_planar_complex/planar_complex.cu b/examples/10_planar_complex/planar_complex.cu index 1ee8a069..8d00b677 100644 --- a/examples/10_planar_complex/planar_complex.cu +++ b/examples/10_planar_complex/planar_complex.cu @@ -275,10 +275,10 @@ public: int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; - int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); - int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); - int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); - int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); + typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); + typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); diff --git a/examples/11_planar_complex_array/planar_complex_array.cu b/examples/11_planar_complex_array/planar_complex_array.cu index e74ba10a..6f07150f 100644 --- a/examples/11_planar_complex_array/planar_complex_array.cu +++ b/examples/11_planar_complex_array/planar_complex_array.cu @@ -292,10 +292,11 @@ public: int64_t batch_stride_C = int64_t(problem_size.m()) * 
problem_size.n() * 2; int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; - int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); - int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); - int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); - int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); + typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); + typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); + int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); diff --git a/examples/13_two_tensor_op_fusion/README.md b/examples/13_two_tensor_op_fusion/README.md index d89d876a..556cad35 100644 --- a/examples/13_two_tensor_op_fusion/README.md +++ b/examples/13_two_tensor_op_fusion/README.md @@ -48,6 +48,10 @@ addition to its own input activation tile. Therefore the input activation warp t 2nd GEMM/Conv only depends on the output warp accumulator of the 1st GEMM/Conv in the register file, and the operation can be fully register-file-resident. +When applying the above constraint to convolutions, it is required that the 2nd Convolution +kernel doesn't have halos such that data used by each threadblock doesn't depend on any other +threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without any paddings. + # Copyright Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
diff --git a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h index 305d1829..53b2fc8e 100644 --- a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h +++ b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h @@ -36,8 +36,6 @@ #include "device/b2b_implicit_gemm_convolution.h" #include "b2b_conv2d_run.h" -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 ( @@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 ( {128, 56, 56, 64} // output size (NPQK) ); -void run_nonfused_conv2d_fprop_f16_sm75() { +bool run_nonfused_conv2d_fprop_f16_sm75() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_f16_sm75() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_f16_sm75() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_f16_sm75() { +bool run_fused_conv2d_fprop_f16_sm75() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_f16_sm75() { ElementC, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -207,9 +208,10 @@ void run_fused_conv2d_fprop_f16_sm75() { else std::cout << "Fail\n"; + return pass; } -void 
run_nonfused_conv2d_fprop_optimized_f16_sm75() { +bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -242,7 +244,8 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm75() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -287,9 +290,10 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm75() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_optimized_f16_sm75() { +bool run_fused_conv2d_fprop_optimized_f16_sm75() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -313,7 +317,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm75() { ElementC, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -359,10 +364,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm75() { else std::cout << "Fail\n"; + return pass; } //////////////////////////////////////////////////////////////////////////////// - -#endif // if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - diff --git a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h index e14134e9..7451c5b4 100644 --- a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h +++ b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h @@ -36,8 +36,6 @@ #include "device/b2b_implicit_gemm_convolution.h" #include "b2b_conv2d_run.h" -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - 
//////////////////////////////////////////////////////////////////////////////// cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 ( @@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 ( {128, 56, 56, 64} // output size (NPQK) ); -void run_nonfused_conv2d_fprop_f16_sm80() { +bool run_nonfused_conv2d_fprop_f16_sm80() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_f16_sm80() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_f16_sm80() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_f16_sm80() { +bool run_fused_conv2d_fprop_f16_sm80() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_f16_sm80() { ElementC, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -205,9 +206,10 @@ void run_fused_conv2d_fprop_f16_sm80() { else std::cout << "Fail\n"; + return pass; } -void run_nonfused_conv2d_fprop_optimized_f16_sm80() { +bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -240,7 +242,8 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm80() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -285,9 +288,10 @@ void run_nonfused_conv2d_fprop_optimized_f16_sm80() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_optimized_f16_sm80() 
{ +bool run_fused_conv2d_fprop_optimized_f16_sm80() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; @@ -311,7 +315,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm80() { ElementC, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -355,9 +360,8 @@ void run_fused_conv2d_fprop_optimized_f16_sm80() { else std::cout << "Fail\n"; + return pass; } //////////////////////////////////////////////////////////////////////////////// -#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - diff --git a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.h b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.h index 2cb4ac2e..c7ba4d9a 100644 --- a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.h +++ b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.h @@ -36,8 +36,6 @@ #include "device/b2b_implicit_gemm_convolution.h" #include "b2b_interleaved_conv2d_run.h" -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( @@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 ( {128, 56, 56, 64} // output size (NPQK) ); -void run_nonfused_conv2d_fprop_s8_sm75() { +bool run_nonfused_conv2d_fprop_s8_sm75() { using ElementA = int8_t; using ElementB = int8_t; @@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_s8_sm75() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_s8_sm75() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_s8_sm75() { +bool run_fused_conv2d_fprop_s8_sm75() { using ElementA = int8_t; using ElementB = int8_t; @@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_s8_sm75() { ElementC, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -207,9 +208,10 @@ void run_fused_conv2d_fprop_s8_sm75() { else std::cout << "Fail\n"; + return pass; } -void run_nonfused_conv2d_fprop_optimized_s8_sm75() { +bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { using ElementA = int8_t; using ElementB = int8_t; @@ -242,7 +244,8 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm75() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -287,9 +290,10 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm75() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_optimized_s8_sm75() { +bool run_fused_conv2d_fprop_optimized_s8_sm75() { using ElementA = int8_t; using ElementB = int8_t; @@ -313,7 +317,8 @@ void run_fused_conv2d_fprop_optimized_s8_sm75() { ElementC, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -359,9 +364,8 @@ void run_fused_conv2d_fprop_optimized_s8_sm75() { else std::cout << "Fail\n"; + return pass; } //////////////////////////////////////////////////////////////////////////////// -#endif // if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - diff --git 
a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.h b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.h index c73d6c69..b1d92665 100644 --- a/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.h +++ b/examples/13_two_tensor_op_fusion/b2b_conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.h @@ -36,8 +36,6 @@ #include "device/b2b_implicit_gemm_convolution.h" #include "b2b_interleaved_conv2d_run.h" -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 ( @@ -57,7 +55,7 @@ cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 ( {128, 56, 56, 64} // output size (NPQK) ); -void run_nonfused_conv2d_fprop_s8_sm80() { +bool run_nonfused_conv2d_fprop_s8_sm80() { using ElementA = int8_t; using ElementB = int8_t; @@ -90,7 +88,8 @@ void run_nonfused_conv2d_fprop_s8_sm80() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -135,9 +134,10 @@ void run_nonfused_conv2d_fprop_s8_sm80() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_s8_sm80() { +bool run_fused_conv2d_fprop_s8_sm80() { using ElementA = int8_t; using ElementB = int8_t; @@ -161,7 +161,8 @@ void run_fused_conv2d_fprop_s8_sm80() { ElementC, 8 * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -207,9 +208,10 @@ void run_fused_conv2d_fprop_s8_sm80() { else std::cout << "Fail\n"; + return pass; } -void run_nonfused_conv2d_fprop_optimized_s8_sm80() { +bool 
run_nonfused_conv2d_fprop_optimized_s8_sm80() { using ElementA = int8_t; using ElementB = int8_t; @@ -242,7 +244,8 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm80() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -287,9 +290,10 @@ void run_nonfused_conv2d_fprop_optimized_s8_sm80() { else std::cout << "Fail\n"; + return pass; } -void run_fused_conv2d_fprop_optimized_s8_sm80() { +bool run_fused_conv2d_fprop_optimized_s8_sm80() { using ElementA = int8_t; using ElementB = int8_t; @@ -313,7 +317,8 @@ void run_fused_conv2d_fprop_optimized_s8_sm80() { ElementC, 8 * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -359,10 +364,9 @@ void run_fused_conv2d_fprop_optimized_s8_sm80() { else std::cout << "Fail\n"; + return pass; } //////////////////////////////////////////////////////////////////////////////// -#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h b/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h index 50da709e..e0e2f456 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h @@ -39,14 +39,12 @@ #include "device/b2b_gemm.h" #include "b2b_gemm_run.h" -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*1600, 64, 576); cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*1600, 128, 64); -void run_nonfused_gemm_f16() { +bool run_nonfused_gemm_f16() { using ElementOutput = cutlass::half_t; 
using ElementAccumulator = cutlass::half_t; @@ -80,7 +78,8 @@ void run_nonfused_gemm_f16() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -116,9 +115,11 @@ void run_nonfused_gemm_f16() { std::cout << "Pass\n"; else std::cout << "Fail\n"; + + return pass; } -void run_fused_gemm_f16() { +bool run_fused_gemm_f16() { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; @@ -140,7 +141,8 @@ void run_fused_gemm_f16() { ElementOutput, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -183,7 +185,6 @@ void run_fused_gemm_f16() { else std::cout << "Fail\n"; + return passed; } //////////////////////////////////////////////////////////////////////////////// - -#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm80.h b/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm80.h index 749ece2b..3a64da84 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm80.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm80.h @@ -39,14 +39,12 @@ #include "device/b2b_gemm.h" #include "b2b_gemm_run.h" -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*1600, 64, 576); cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*1600, 128, 64); -void run_nonfused_gemm_f16_sm80() { +bool run_nonfused_gemm_f16_sm80() { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; @@ -80,7 +78,8 @@ void 
run_nonfused_gemm_f16_sm80() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3 @@ -116,9 +115,11 @@ void run_nonfused_gemm_f16_sm80() { std::cout << "Pass\n"; else std::cout << "Fail\n"; + + return pass; } -void run_fused_gemm_f16_sm80() { +bool run_fused_gemm_f16_sm80() { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; @@ -140,7 +141,8 @@ void run_fused_gemm_f16_sm80() { ElementOutput, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -183,7 +185,7 @@ void run_fused_gemm_f16_sm80() { else std::cout << "Fail\n"; + return passed; + } //////////////////////////////////////////////////////////////////////////////// - -#endif //#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h b/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h index 2c2610b7..c45741f0 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h @@ -39,14 +39,12 @@ #include "device/b2b_gemm.h" #include "b2b_interleaved_gemm_run.h" -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*1600, 64, 576); cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*1600, 128, 64); -void run_nonfused_gemm_s8() { +bool run_nonfused_gemm_s8() { using ElementOutput = int8_t; using ElementAccumulator = int32_t; @@ -80,7 +78,8 @@ void run_nonfused_gemm_s8() { ElementOutput, 64 / cutlass::sizeof_bits::value, 
ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -116,9 +115,11 @@ void run_nonfused_gemm_s8() { std::cout << "Pass\n"; else std::cout << "Fail\n"; + + return pass; } -void run_fused_gemm_s8() { +bool run_fused_gemm_s8() { using ElementOutput = int8_t; using ElementAccumulator = int32_t; @@ -140,7 +141,8 @@ void run_fused_gemm_s8() { ElementOutput, InstructionShape::kM * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -151,8 +153,6 @@ void run_fused_gemm_s8() { ElementCompute >; - - using B2bGemm = cutlass::gemm::device::B2bGemm< int8_t, cutlass::layout::ColumnMajorInterleaved<32>, @@ -183,7 +183,7 @@ void run_fused_gemm_s8() { else std::cout << "Fail\n"; + return passed; + } //////////////////////////////////////////////////////////////////////////////// - -#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h b/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h index 8b9eefc6..2ded1478 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h @@ -39,14 +39,12 @@ #include "device/b2b_gemm.h" #include "b2b_interleaved_gemm_run.h" -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*1600, 64, 576); cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*1600, 128, 64); -void run_nonfused_gemm_s8_sm80() { +bool run_nonfused_gemm_s8_sm80() { using ElementOutput = int8_t; using ElementAccumulator = int32_t; @@ -80,7 +78,8 @@ void run_nonfused_gemm_s8_sm80() { 
ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, @@ -106,7 +105,8 @@ void run_nonfused_gemm_s8_sm80() { ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, @@ -124,9 +124,11 @@ void run_nonfused_gemm_s8_sm80() { std::cout << "Pass\n"; else std::cout << "Fail\n"; + + return pass; } -void run_fused_gemm_s8_sm80() { +bool run_fused_gemm_s8_sm80() { using ElementOutput = int8_t; using ElementAccumulator = int32_t; @@ -148,7 +150,8 @@ void run_fused_gemm_s8_sm80() { ElementOutput, 8 * InstructionShape::kN / 32, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; using EpilogueOutputOp1 = @@ -156,11 +159,10 @@ void run_fused_gemm_s8_sm80() { ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; - - using B2bGemm = cutlass::gemm::device::B2bGemm< int8_t, cutlass::layout::ColumnMajorInterleaved<32>, @@ -183,8 +185,7 @@ void run_fused_gemm_s8_sm80() { 16, 16, false, - cutlass::arch::OpMultiplyAddSaturate, - true + cutlass::arch::OpMultiplyAddSaturate >; B2bInterleavedFusedGemmRun fusedGemm; @@ -196,7 +197,6 @@ void run_fused_gemm_s8_sm80() { else std::cout << "Fail\n"; + return passed; } //////////////////////////////////////////////////////////////////////////////// - -#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h index b72ac291..25545e4a 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +++ 
b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -115,9 +115,7 @@ template < /// Operation performed by GEMM typename Operator_ = typename DefaultGemmConfiguration< OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, - ElementAccumulator_>::Operator, - /// Whether Beta is zero or not - bool IsBetaZero = false> + ElementAccumulator_>::Operator> class B2bGemm { public: @@ -148,7 +146,6 @@ class B2bGemm { static int const kAlignmentB = AlignmentB; static int const kAlignmentC = EpilogueOutputOp1::kCount; static bool const kSplitKSerial = SplitKSerial; - static bool const kIsBetaZero = IsBetaZero; static ComplexTransform const kTransformA = ComplexTransform::kNone; static ComplexTransform const kTransformB = ComplexTransform::kNone; @@ -175,8 +172,7 @@ class B2bGemm { ThreadblockSwizzle, kStages, kSplitKSerial, - Operator, - kIsBetaZero + Operator >::B2bGemmKernel; /// Argument structure @@ -422,7 +418,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h index 64f97b7b..da5e4dc2 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h @@ -255,7 +255,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/examples/13_two_tensor_op_fusion/fused_conv2d.cu b/examples/13_two_tensor_op_fusion/fused_conv2d.cu index f6bb3d72..a3db1c6d 100644 --- a/examples/13_two_tensor_op_fusion/fused_conv2d.cu +++ 
b/examples/13_two_tensor_op_fusion/fused_conv2d.cu @@ -28,53 +28,14 @@ #include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm75.h" #include "b2b_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.h" -int run() { - - cudaDeviceProp props; - - cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (error != cudaSuccess) { - std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return -1; - } - - if (!(props.major * 10 + props.minor >= 75)) { - std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." - << std::endl; - - // Returning zero so this test passes on older Toolkits. Its actions are no-op. - return 0; - } - -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - std::cout << "Running on SM80" << std::endl; - run_nonfused_conv2d_fprop_optimized_f16_sm80(); - run_fused_conv2d_fprop_optimized_f16_sm80(); - run_nonfused_conv2d_fprop_optimized_s8_sm80(); - run_fused_conv2d_fprop_optimized_s8_sm80(); -#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - std::cout << "Running on SM75" << std::endl; - run_nonfused_conv2d_fprop_optimized_f16_sm75(); - run_fused_conv2d_fprop_optimized_f16_sm75(); - run_nonfused_conv2d_fprop_optimized_s8_sm75(); - run_fused_conv2d_fprop_optimized_s8_sm75(); -#endif - - return 0; -} - -int main() { - +int run_sm75() { bool notSupported = false; // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. // // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { - std::cerr << "Tensor Core operations used in this example must be compiled with CUDA 10.2 Toolkit or later." 
<< std::endl; - notSupported = true; - } cudaDeviceProp props; @@ -85,10 +46,7 @@ int main() { return -1; } - if (!(props.major * 10 + props.minor >= 75)) { - std::cerr << "Tensor Ops used in this example must be run on a machine with compute capability at least 75." - << std::endl; - + if (!(props.major == 7 && props.minor >= 5)) { notSupported = true; } @@ -96,7 +54,83 @@ int main() { // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } - - return run(); + + bool pass = 1; + + std::cout << "Running on SM75" << std::endl; + pass &= run_nonfused_conv2d_fprop_optimized_f16_sm75(); + pass &= run_fused_conv2d_fprop_optimized_f16_sm75(); + pass &= run_nonfused_conv2d_fprop_optimized_s8_sm75(); + pass &= run_fused_conv2d_fprop_optimized_s8_sm75(); + + if(pass) + return 1; + else + return -1; + +} + +int run_sm80() { + bool notSupported = false; + + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { + notSupported = true; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major == 8 && props.minor >= 0)) { + notSupported = true; + } + + if (notSupported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + bool pass = 1; + + std::cout << "Running on SM80" << std::endl; + pass &= run_nonfused_conv2d_fprop_optimized_f16_sm80(); + pass &= run_fused_conv2d_fprop_optimized_f16_sm80(); + pass &= run_nonfused_conv2d_fprop_optimized_s8_sm80(); + pass &= run_fused_conv2d_fprop_optimized_s8_sm80(); + + if(pass) + return 1; + else + return -1; + +} + + +int main() { + + int result = 0; + + result = run_sm80(); + + if(!result) { // not supported + result = run_sm75(); + + if(!result) { + std::cout << "This example isn't supported on current architecture" << std::endl; + } + + } + + if(result >= 0) + return 0; + else + return -1; } diff --git a/examples/13_two_tensor_op_fusion/fused_gemm.cu b/examples/13_two_tensor_op_fusion/fused_gemm.cu index 65bad943..7dd419c1 100644 --- a/examples/13_two_tensor_op_fusion/fused_gemm.cu +++ b/examples/13_two_tensor_op_fusion/fused_gemm.cu @@ -28,36 +28,15 @@ #include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h" #include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h" -int run() { - -#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - std::cout << "Running on SM80" << std::endl; - run_nonfused_gemm_f16_sm80(); - run_fused_gemm_f16_sm80(); - run_nonfused_gemm_s8_sm80(); - run_fused_gemm_s8_sm80(); -#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - std::cout << "Running on SM75" << std::endl; - run_nonfused_gemm_f16(); - run_fused_gemm_f16(); - run_nonfused_gemm_s8(); - run_fused_gemm_s8(); -#endif - - return 0; -} - -int main() { - +int run_sm75() { bool notSupported = false; // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. // // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { - std::cerr << "Tensor Core operations used in this example must be compiled with CUDA 10.2 Toolkit or later." 
<< std::endl; - notSupported = true; + } cudaDeviceProp props; @@ -68,10 +47,7 @@ int main() { return -1; } - if (!(props.major * 10 + props.minor >= 75)) { - std::cerr << "Tensor Ops used in this example must be run on a machine with compute capability at least 75." - << std::endl; - + if (!(props.major == 7 && props.minor >= 5)) { notSupported = true; } @@ -80,6 +56,86 @@ int main() { return 0; } - return run(); + bool pass = true; + + std::cout << "Running on SM75" << std::endl; + pass &= run_nonfused_gemm_f16(); + pass &= run_fused_gemm_f16(); + pass &= run_nonfused_gemm_s8(); + pass &= run_fused_gemm_s8(); + + if(pass) + return 1; + else + return -1; + + } +int run_sm80() { + bool notSupported = false; + + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { + notSupported = true; + + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major == 8 && props.minor >= 0)) { + notSupported = true; + } + + if (notSupported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + bool pass = true; + + std::cout << "Running on SM80" << std::endl; + pass &= run_nonfused_gemm_f16_sm80(); + pass &= run_fused_gemm_f16_sm80(); + pass &= run_nonfused_gemm_s8_sm80(); + pass &= run_fused_gemm_s8_sm80(); + + if(pass) + return 1; + else + return -1; + +} + + +int main() { + + int result = 0; + + result = run_sm80(); + + if(!result) { // not supported + result = run_sm75(); + + if(!result) { + std::cout << "This example isn't supported on current architecture" << std::endl; + } + + } + + if(result >= 0) + return 0; + else + return -1; +} + + + diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index 5627fc31..51e06a94 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -66,6 +66,7 @@ struct B2bGemm { cutlass::gemm::GemmCoord problem_size_0; cutlass::gemm::GemmCoord problem_size_1; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename B2bMma::IteratorA0::Params params_A0; typename B2bMma::IteratorA0::TensorRef ref_A0; typename B2bMma::IteratorB0::Params params_B0; @@ -91,7 +92,7 @@ struct B2bGemm { // CUTLASS_HOST_DEVICE - Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), gemm_k_iterations_1(0), gemm_k_size_1(0) { } CUTLASS_HOST_DEVICE @@ -112,6 +113,7 @@ struct B2bGemm { problem_size_0(problem_size_0), problem_size_1(problem_size_1), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A0(ref_A0.layout()), ref_A0(ref_A0), params_B0(ref_B0.layout()), @@ -211,7 +213,7 @@ struct B2bGemm { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early 
exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -315,7 +317,7 @@ struct B2bGemm { // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h b/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h index 9a7b462a..b0bda3d0 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h @@ -209,6 +209,7 @@ struct B2bImplicitGemmConvolution { cutlass::gemm::GemmCoord grid_tiled_shape; gemm::GemmCoord implicit_gemm_problem_size_0; gemm::GemmCoord implicit_gemm_problem_size_1; + int swizzle_log_tile; int gemm_k_iterations_0; int gemm_k_iterations_1; typename B2bMma::IteratorA0::Params iterator_A0; @@ -233,7 +234,7 @@ struct B2bImplicitGemmConvolution { // CUTLASS_HOST_DEVICE - Params(): gemm_k_iterations_0(0), gemm_k_iterations_1(0) { } + Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { } /// CUTLASS_HOST_DEVICE @@ -245,7 +246,6 @@ struct B2bImplicitGemmConvolution { problem_size_1(args.problem_size_1), implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)), implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)), - grid_tiled_shape(grid_tiled_shape), iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())), ptr_A0(args.ref_A0.data()), iterator_B0(args.problem_size_0, args.ref_B0.layout()), @@ -272,6 +272,8 @@ struct B2bImplicitGemmConvolution { implicit_gemm_problem_size_0, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, args.problem_size_0.split_k_slices); + + 
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); } }; @@ -296,7 +298,7 @@ struct B2bImplicitGemmConvolution { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || @@ -379,7 +381,7 @@ struct B2bImplicitGemmConvolution { // Compute logical position within grid threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // If performing a reduction via split-K, fetch the initial synchronization if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h index cdf53756..83d9fe96 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h @@ -111,9 +111,7 @@ template < /// If true, kernel is configured to support serial reduction in the epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator, - /// Beta is zero or not - bool IsBetaZero = false + typename Operator > struct DefaultB2bGemm; @@ -321,9 +319,7 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator, - /// Is Beta zero or not - bool IsBetaZero> + typename Operator> struct DefaultB2bGemm< ElementA, layout::ColumnMajorInterleaved, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, @@ -332,7 +328,7 @@ struct DefaultB2bGemm< ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, - SplitKSerial, Operator, IsBetaZero> { + SplitKSerial, Operator> { 
using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -353,8 +349,7 @@ struct DefaultB2bGemm< using Epilogue = typename cutlass::epilogue::threadblock:: DefaultInterleavedEpilogueTensorOp< ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, - 64 / sizeof_bits::value, InterleavedK, - IsBetaZero>::Epilogue; + 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. using B2bGemmKernel = kernel::B2bGemm; @@ -397,9 +392,7 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator, - /// Is Beta zero or not - bool IsBetaZero> + typename Operator> struct DefaultB2bGemm, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, @@ -407,7 +400,7 @@ struct DefaultB2bGemm, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> { + ThreadblockSwizzle, 2, SplitKSerial, Operator> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -426,8 +419,7 @@ struct DefaultB2bGemm, using Epilogue = typename cutlass::epilogue::threadblock:: DefaultInterleavedEpilogueTensorOp< ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, - 64 / sizeof_bits::value, InterleavedK, - IsBetaZero>::Epilogue; + 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. 
using B2bGemmKernel = kernel::B2bGemm; diff --git a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu index 58f5a874..a27a802a 100644 --- a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu +++ b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu @@ -43,14 +43,122 @@ fp32 data by using NVIDIA Ampere architecture. #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/command_line.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/tensor_view_io.h" + #include "helper.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + float alpha; + float beta; + + bool reference_check; + int iterations; + + Options(): + help(false), + problem_size({5120, 4096, 4096}), + batch_count(1), + reference_check(true), + iterations(20), + alpha(1), + beta() { } + + bool valid() { + return true; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if 
(cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("iterations", iterations); + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "14_ampere_tf32_tensorop_gemm example\n\n" + << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m GEMM M dimension\n" + << " --n GEMM N dimension\n" + << " --k GEMM K dimension\n" + << " --alpha Epilogue scalar alpha\n" + << " --beta Epilogue scalar beta\n\n" + << " --iterations Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n" + << " --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = problem_size.product() * batch_count; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + // The code section below describes datatype for input, output matrices and computation between // elements in input matrices. 
using ElementAccumulator = float; // <- data type of accumulator @@ -111,14 +219,10 @@ using Gemm = cutlass::gemm::device::Gemm; -int run() { - - const int length_m = 5120; - const int length_n = 4096; - const int length_k = 4096; +int run(Options &options) { // Create a tuple of problem size for matrix multiplication - cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + cutlass::gemm::GemmCoord problem_size = options.problem_size; // Initialize tensors using CUTLASS helper functions cutlass::HostTensor tensor_a( @@ -166,8 +270,8 @@ int run() { tensor_ref_d.sync_device(); // Initialize alpha and beta for dot product computation - ElementComputeEpilogue alpha = ElementComputeEpilogue(1); - ElementComputeEpilogue beta = ElementComputeEpilogue(0); + ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); + ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta); // Split K dimension into 1 partitions int split_k_slices = 1; @@ -199,9 +303,74 @@ int run() { status = gemm_op.initialize(arguments, workspace.get()); CUTLASS_CHECK(status); - // Launch initialized CUTLASS kernel - status = gemm_op(); - CUTLASS_CHECK(status); + // Result structure + Result result; + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMMs + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Stop profiling loop + // + + // Record an 
event when the GEMMs are complete + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } // Create instantiation for device reference gemm kernel cutlass::reference::device::Gemm tensor_e( cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); // Same size as the above. The above one needs to be reordered and stored in this one. 
- cutlass::HostTensor tensor_e_reordered( + cutlass::HostTensor tensor_e_reordered( cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); // Fill input and output matrices on host using CUTLASS helper functions diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index 4c417bc6..0ba3cb2e 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -158,7 +158,7 @@ using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSw constexpr int NumStages = 3; // This code section describe iterator algorithm selected is Analytic or Optimized -static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kAnalytic; +static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; // This code section describes the epilogue part of the kernel, we use default value using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -189,7 +189,6 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; - ///////////////////////////////////////////////////////////////////////////////////////////////// // Command line options parsing @@ -755,6 +754,3 @@ int main(int argc, char const **args) { } ///////////////////////////////////////////////////////////////////////////////////////////////// - - - diff --git a/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt b/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt new file mode 100644 index 00000000..2c085d9b --- /dev/null +++ b/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +cutlass_example_add_executable( + 18_ampere_fp64_tensorop_affine2_gemm + ampere_fp64_tensorop_affine2_gemm.cu + ) + diff --git a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu new file mode 100644 index 00000000..4b123940 --- /dev/null +++ b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +In the normal GEMM, the fast changing dimension of a matrix always has a stride +equal to 1, e.g. ColumnMajor and RowMajor matrix. Affine2 matrix can have +larger than 1 stride in both dimensions. To support such a layout, we need to +change the method used to visit the global memory: + + 1. We can only visit 1 element at a time because elements are not stored + consecutively anymore. Vectorized load/store is not possible. + 2. One extra multiplication is needed in calculating the global memory + address + addr = base_pointer + coord1 * stride1 + coord2 * stride2 + +The rest of the GEMM, which includes shared memory load/store and mma computation, +is the same. + +This example uses Ampere fp64 tensor core Affine2 GEMM as an example. SIMT +(e.g. sgemm, dgemm) also supports the Affine2 layout. 
+*/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/kernel/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "helper.h" + +// The code section below describes datatype for input, output tensors and computation between +// elements +using ElementAccumulator = double; // Data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation +using ElementInputA = double; // Data type of elements in input tensor +using ElementInputB = double; // Data type of elements in input tensor +using ElementOutput = double; // Data type of elements in output tensor + +// Since Affine2 explicitly lists the strides of both dimensions, it does not really matter if +// it is column major or row major. However, it helps CUTLASS to improve the load locality if +// CUTLASS can know which dimension of the A/B operand has the smaller stride or is more dense. +// +// Affine2 ColumnMajor means the row stride is smaller and Affine2 RowMajor means the column +// stride is smaller. +// +// The Affine2 epilogue reuses the AffineN epilogue so it does not need to specify column major +// or row major. 
+using LayoutInputA = cutlass::layout::AffineRank2ColumnMajor; +using LayoutInputB = cutlass::layout::AffineRank2RowMajor; +using LayoutOutput = cutlass::layout::AffineRankN<2>; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape + +// This code section describes tile size a warp will compute +using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape + +// This code section describes the size of MMA op +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +// This code section describes the epilogue part of the kernel, we use default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 1, // The number of elements per memory + // access has. It has to be 1 for + // affine2. 
+ ElementComputeEpilogue>; + +using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal< + ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 1, // AlignmentA has to be 1 + ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 1, // AlignmentB has to be 1 + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int run() { + + // Construct Gemm ProblemSize with user defined output size + cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024}; + + // Stride factor shows the distance between two elements in the different dimensions. The + // first value is the logical distance between two rows, the second is between two columns. + // CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory::layout_factory + // to help convert stride_factor to the two strides. + // + // It is also totally fine to compute the strides directly without using the utility to + // construct the affine2 layout. + typename LayoutInputA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutInputB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutOutput::Stride::Index stride_factor_C[] = {7, 8}; + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a(problem_size.mk(), + cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), + stride_factor_A)); + cutlass::HostTensor tensor_b(problem_size.kn(), + cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), + stride_factor_B)); + + // Create matrix C used to load for bias addition. 
+ cutlass::HostTensor tensor_c(problem_size.mn(), + cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), + stride_factor_C)); + + // Create matrix D used to store output from CUTLASS kernel + cutlass::HostTensor tensor_d(problem_size.mn(), + cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), + stride_factor_C)); + + // Create matrix D with dimensions M x N used to store output from reference + // kernel + cutlass::HostTensor tensor_ref_d(problem_size.mn(), + cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), + stride_factor_C)); + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 0); // <- Fill matrix A on host with uniform-distribution random data + + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(4), + ElementInputB(-4), + 0); // <- Fill matrix B on host with uniform-distribution random data + + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(1); + + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + + int batch_count = 1; + + // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_a.device_ref().data(), // <- reference to matrix A on device + tensor_b.device_ref().data(), // <- reference to matrix B on device + tensor_c.device_ref().data(), // <- reference to matrix C on device + tensor_d.device_ref().data(), // <- reference to matrix D on device + tensor_a.layout().capacity(problem_size.mn()), + tensor_b.layout().capacity(problem_size.kn()), + tensor_c.layout().capacity(problem_size.mn()), + tensor_d.layout().capacity(problem_size.mn()), + tensor_a.layout().stride(), + tensor_b.layout().stride(), + tensor_c.layout().stride(), + tensor_d.layout().stride() + }; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + + CUTLASS_CHECK(status); + + // + // Create instantiation for device reference gemm kernel + // + + // Launch device reference to compute strictly the product A * B + cutlass::reference::device::Gemm< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator> gemm_device; + + gemm_device + ( + problem_size, + alpha, + tensor_a.device_ref(), + tensor_b.device_ref(), + beta, + tensor_c.device_ref(), + tensor_ref_d.device_ref() + ); + + // Wait for 
kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), + tensor_ref_d.host_view()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + std::cout << (pass + ? "Passed" + : "Failed") + << std::endl; + + CUTLASS_CHECK(status); + + return 0; +} + +int main(int argc, char const **args) { + + bool notSupported = false; + + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + return 0; + } + + return run(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/19_tensorop_canonical/CMakeLists.txt b/examples/19_tensorop_canonical/CMakeLists.txt new file mode 100644 index 00000000..297d846e --- /dev/null +++ b/examples/19_tensorop_canonical/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 19_tensorop_canonical + tensorop_canonical.cu +) + diff --git a/examples/19_tensorop_canonical/tensorop_canonical.cu b/examples/19_tensorop_canonical/tensorop_canonical.cu new file mode 100644 index 00000000..9882b477 --- /dev/null +++ b/examples/19_tensorop_canonical/tensorop_canonical.cu @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + This example requires NVIDIA Ampere GPU or later. 
+*/ + +// Standard Library includes +#include +#include +#include + +// CUTLASS Includes +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" + +// CUTLASS Utility Includes +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Define the overall warp-level problem shape +int const kM = 27; +int const kN = 31; +int const kK = 17; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Define a warp-level GEMM operator. +// +// This template could be part of the CUTLASS Template Library or implemented internally. This +// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be +// instantiated in device code. 
+ +namespace cutlass { +namespace gemm { +namespace warp { + +template < + typename Shape, + typename InstructionShape, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementScalar +> +class GemmTensorOp { +public: + + using WarpShape = GemmShape< + ((Shape::kM + InstructionShape::kM - 1) / InstructionShape::kM) * InstructionShape::kM, + ((Shape::kN + InstructionShape::kN - 1) / InstructionShape::kN) * InstructionShape::kN, + InstructionShape::kK + >; + + using MmaWarp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, + InstructionShape, + double, // Data type of A elements + cutlass::layout::RowMajor, // Layout of A matrix + double, // Data type of B elements + cutlass::layout::ColumnMajor, // Layout of B matrix + double, // Data type of C elements + cutlass::layout::RowMajor // Layout of C matrix + >::Type; + + // Number of 'K groups' + int const kKgroups = (Shape::kK + InstructionShape::kK - 1) / InstructionShape::kK; + + // Define a 'FragmentIterator' to iterate over slices of accumulators + using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename MmaWarp::Shape, + InstructionShape, + double, + typename MmaWarp::Policy::Operator::FragmentC, + cutlass::layout::RowMajor + >; + + // Define an epilogue 'Tile Iterator' to iterate over slices of elements in Shared Memory + using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpCanonical< + typename MmaWarp::Shape, + InstructionShape, + double, + cutlass::layout::RowMajor + >; + + using TensorRefA = typename MmaWarp::IteratorA::TensorRef; + using TensorRefB = typename MmaWarp::IteratorB::TensorRef; + using TensorRefC = typename AccumulatorTileIterator::TensorRef; + +public: + CUTLASS_HOST_DEVICE + GemmTensorOp() { } + + CUTLASS_DEVICE + void operator()( + ElementScalar alpha, + TensorRefA ref_A, + TensorRefB ref_B, + ElementScalar beta, + TensorRefC ref_C, + 
TensorRefC ref_D, + int lane_id) const { + + // Instantiate iterators pointing to slices of the A and B matrices in shared memory + typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); + typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); + + // Instantiate and clear accumulator tile holding the C matrix + typename MmaWarp::FragmentC accum; + accum.clear(); + + // Instantiate the warp-level matrix multiply operator + MmaWarp mma_op; + + // Instantiate fragments holding the slice of the matrix held by each warp + typename MmaWarp::FragmentA frag_A[2]; + typename MmaWarp::FragmentB frag_B[2]; + + // Load fragments from shared memory + iter_A.load(frag_A[0]); + iter_B.load(frag_B[0]); + + ++iter_A; + ++iter_B; + + // Load fragments from shared memory + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + + // Load fragments from shared memory + iter_A.load(frag_A[(k + 1) % 2]); + iter_B.load(frag_B[(k + 1) % 2]); + + ++iter_A; + ++iter_B; + + // Compute the matrix multiply + mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); + } + + // Instantiate iterators + FragmentIterator accum_frag_it(accum); + AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id); + AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); + + // Define function objects for linear scaling operation + cutlass::multiplies mul_source; + cutlass::multiply_add mul_add_accumulator; + + // Iterate over the epilogue components + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { + + // Define storage for slices of the accumulators + typename FragmentIterator::Fragment accum_fragment; + typename FragmentIterator::Fragment source_fragment; + + // Select a slice of accumulators from the accumulator tile + accum_frag_it.load(accum_fragment); + ++accum_frag_it; + + // Load a corresponding slice from Shared memory + source_tile_it.load(source_fragment); + 
++source_tile_it; + + // Compute linear scaling - alpha * AB + beta * C + source_fragment = mul_source(beta, source_fragment); + accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); + + // Store the result to shared memory + dest_tile_it.store(accum_fragment); + ++dest_tile_it; + } + } +}; + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held +// in Shared Memory. +__global__ void kernel( + double *D_gmem, + double alpha, + double const *A_gmem, + double const *B_gmem, + double beta, + double const *C_gmem) { + + // Define several matrices in shared memory + __shared__ double A[kM][kK]; + __shared__ double B[kN][kK]; + __shared__ double C[kM][kN]; + + // Copy data into SMEM + if (threadIdx.x == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 0; m < kM; ++m) { + for (int k = 0; k < kK; ++k) { + A[m][k] = A_gmem[m * kK + k]; + } + } + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 0; n < kN; ++n) { + for (int k = 0; k < kK; ++k) { + B[n][k] = B_gmem[n * kK + k]; + } + } + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 0; m < kM; ++m) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 0; n < kN; ++n) { + C[m][n] = C_gmem[m * kN + n]; + } + } + } + + __syncthreads(); + + // + // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), + // overall shape, data type of each operand, and layout of each operand. 
+ // + + using GemmTensorOp = cutlass::gemm::warp::GemmTensorOp< + cutlass::gemm::GemmShape, + cutlass::gemm::GemmShape<8, 8, 4>, + double, // Data type of A elements + cutlass::layout::RowMajor, // Layout of A matrix + double, // Data type of B elements + cutlass::layout::ColumnMajor, // Layout of B matrix + double, // Data type of C elements + cutlass::layout::RowMajor, // Layout of C matrix + double // Scalar type of alpha and beta + >; + + // Instantiate the GEMM operator + GemmTensorOp gemm; + + // Execute the warp-level GEMM operation + gemm( + alpha, + {&A[0][0], kK}, + {&B[0][0], kK}, + beta, + {&C[0][0], kN}, + {&C[0][0], kN}, + threadIdx.x); + + __syncthreads(); + + // Copy data into SMEM + if (threadIdx.x == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 0; m < kM; ++m) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 0; n < kN; ++n) { + D_gmem[m * kN + n] = C[m][n]; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Entry point to canonical warp-level GEMM operation +int main(int argc, const char *arg[]) { + + bool notSupported = false; + + // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::cerr << "This example requires compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + // Return 0 so tests are considered passing if run on unsupported platforms. 
+ return 0; + } + + cutlass::HostTensor A({kM, kK}); + cutlass::HostTensor B({kK, kN}); + cutlass::HostTensor C({kM, kN}); + cutlass::HostTensor D({kM, kN}); + + uint64_t seed = 2020; + double max = 8; + double min = -8; + + cutlass::reference::host::TensorFillRandomUniform( + A.host_view(), + seed, + max, + min, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + B.host_view(), + seed + 17, + max, + min, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + C.host_view(), + seed + 31, + max, + min, + 0 + ); + + A.sync_device(); + B.sync_device(); + C.sync_device(); + D.sync_device(); + + dim3 grid(1,1); + dim3 block(32, 1, 1); + + double alpha = 2.25; + double beta = 1.24; + + kernel<<< grid, block >>>( + D.device_data(), + alpha, + A.device_data(), + B.device_data(), + beta, + C.device_data() + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Failed to synchronize device after kernel launch." << std::endl; + return -1; + } + + D.sync_host(); + + // Compute reference on host + cutlass::HostTensor D_ref({kM, kN}, false); + + cutlass::reference::host::GemmComplex( + {kM, kN, kK}, + alpha, + A.host_ref(), + cutlass::ComplexTransform::kNone, + B.host_ref(), + cutlass::ComplexTransform::kNone, + beta, + C.host_ref(), + D_ref.host_ref(), + double() + ); + + // Verify reference matches computed + if (!cutlass::reference::host::TensorEquals( + D.host_view(), + D_ref.host_view())) { + + std::cerr + << "A =\n" << A.host_view() + << "\n\nB = \n" << B.host_view() + << "\n\nC = " << C.host_view() + << "\n\nRef =\n" << D_ref.host_view() + << "\n\nD =\n" << D.host_view() << "\n\n"; + + std::cerr << "Error - device results mismatch host reference." 
<< std::endl; + + return -1; + } + + std::cout << "Passed" << std::endl; + + return 0; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/20_simt_canonical/CMakeLists.txt b/examples/20_simt_canonical/CMakeLists.txt new file mode 100644 index 00000000..f7c30275 --- /dev/null +++ b/examples/20_simt_canonical/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +cutlass_example_add_executable( + 20_simt_canonical + simt_canonical.cu +) + diff --git a/examples/20_simt_canonical/simt_canonical.cu b/examples/20_simt_canonical/simt_canonical.cu new file mode 100644 index 00000000..69cf0c8b --- /dev/null +++ b/examples/20_simt_canonical/simt_canonical.cu @@ -0,0 +1,419 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ **************************************************************************************************/
+
+/*
+  This example requires NVIDIA Maxwell GPU or beyond.
+*/
+
+// Standard Library includes
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+// CUTLASS Includes
+#include "cutlass/cutlass.h"
+#include "cutlass/core_io.h"
+#include "cutlass/functional.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/gemm/warp/mma_simt.h"
+#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
+#include "cutlass/epilogue/warp/tile_iterator_simt.h"
+
+// CUTLASS Utility Includes
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/gemm_complex.h"
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Define the overall warp-level problem shape
+int const kM = 14;
+int const kN = 27;
+int const kK = 17;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Define a warp-level GEMM operator.
+//
+// This template could be part of the CUTLASS Template Library or implemented internally. This
+// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be
+// instantiated in device code.
+ +namespace cutlass { +namespace gemm { +namespace warp { + +template < + typename Shape, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementScalar +> +class GemmSimt { +public: + + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<4, 8>, + cutlass::layout::RowMajorInterleaved<2>, + cutlass::gemm::GemmShape<4, 4, 1> + >; + + using MmaWarp = cutlass::gemm::warp::MmaSimt< + cutlass::gemm::GemmShape<16, 32, 8>, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + Policy + >; + + // Number of 'K groups' + int const kKgroups = Shape::kK; + + using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< + typename MmaWarp::Shape, + typename MmaWarp::ThreadMma, + layout::RowMajor, // SMEM layout + typename MmaWarp::Policy + >; + + using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical< + typename MmaWarp::Shape, + typename MmaWarp::ThreadMma, + float, // ElementAccumulator + layout::RowMajor, // SMEM layout + typename MmaWarp::Policy + >; + + using TensorRefA = typename MmaWarp::IteratorA::TensorRef; + using TensorRefB = typename MmaWarp::IteratorB::TensorRef; + using TensorRefC = typename AccumulatorTileIterator::TensorRef; + +public: + CUTLASS_HOST_DEVICE + GemmSimt() { } + + CUTLASS_DEVICE + void operator()( + ElementScalar alpha, + TensorRefA ref_A, + TensorRefB ref_B, + ElementScalar beta, + TensorRefC ref_C, + TensorRefC ref_D, + int lane_id) const { + + // Instantiate iterators pointing to slices of the A and B matrices in shared memory + typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); + typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); + + // Instantiate and clear accumulator tile holding the C matrix + typename MmaWarp::FragmentC accum; + accum.clear(); + + // Instantiate the warp-level matrix 
multiply operator + MmaWarp mma_op; + + // Instantiate fragments holding the slice of the matrix held by each warp + typename MmaWarp::FragmentA frag_A[2]; + typename MmaWarp::FragmentB frag_B[2]; + + // Load fragments from shared memory + iter_A.load(frag_A[0]); + iter_B.load(frag_B[0]); + + ++iter_A; + ++iter_B; + + // Load fragments from shared memory + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < kKgroups; ++k) { + + // Load fragments from shared memory + iter_A.load(frag_A[(k + 1) % 2]); + iter_B.load(frag_B[(k + 1) % 2]); + + ++iter_A; + ++iter_B; + + // Compute the matrix multiply + mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); + } + + // Instantiate iterators + FragmentIterator accum_frag_it(accum); + AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id); + AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); + + // Define function objects for linear scaling operation + cutlass::multiplies mul_source; + cutlass::multiply_add mul_add_accumulator; + + // Iterate over the epilogue components + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { + + // Define storage for slices of the accumulators + typename FragmentIterator::Fragment accum_fragment; + typename FragmentIterator::Fragment source_fragment; + + // Select a slice of accumulators from the accumulator tile + accum_frag_it.load(accum_fragment); + ++accum_frag_it; + + // Load a corresponding slice from Shared memory + source_tile_it.load(source_fragment); + ++source_tile_it; + + // Compute linear scaling - alpha * AB + beta * C + source_fragment = mul_source(beta, source_fragment); + accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); + + // Store the result to shared memory + dest_tile_it.store(accum_fragment); + ++dest_tile_it; + } + + } + +}; + +} // namespace warp +} // namespace gemm +} // namespace cutlass 
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held +// in Shared Memory. +__global__ void kernel( + float *D_gmem, + float alpha, + float const *A_gmem, + float const *B_gmem, + float beta, + float const *C_gmem) { + + // Define several matrices in shared memory + __shared__ float A[kM][kK]; + __shared__ float B[kN][kK]; + __shared__ float C[kM][kN]; + + // Copy data into SMEM + if (threadIdx.x == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 0; m < kM; ++m) { + for (int k = 0; k < kK; ++k) { + A[m][k] = A_gmem[m * kK + k]; + } + } + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 0; n < kN; ++n) { + for (int k = 0; k < kK; ++k) { + B[n][k] = B_gmem[n * kK + k]; + } + } + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 0; m < kM; ++m) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 0; n < kN; ++n) { + C[m][n] = C_gmem[m * kN + n]; + } + } + } + + __syncthreads(); + + // + // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), + // overall shape, data type of each operand, and layout of each operand. 
+ // + + using GemmSimt = cutlass::gemm::warp::GemmSimt< + cutlass::gemm::GemmShape, + float, // Data type of A elements + cutlass::layout::RowMajor, // Layout of A matrix + float, // Data type of B elements + cutlass::layout::ColumnMajor, // Layout of B matrix + float, // Data type of C elements + cutlass::layout::RowMajor, // Layout of C matrix + float // Scalar type of alpha and beta + >; + + // Instantiate the GEMM operator + GemmSimt gemm; + + // Execute the warp-level GEMM operation + gemm( + alpha, + {&A[0][0], kK}, + {&B[0][0], kK}, + beta, + {&C[0][0], kN}, + {&C[0][0], kN}, + threadIdx.x); + + __syncthreads(); + + // Copy data into SMEM + if (threadIdx.x == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 0; m < kM; ++m) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 0; n < kN; ++n) { + D_gmem[m * kN + n] = C[m][n]; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, const char *arg[]) { + + cutlass::HostTensor A({kM, kK}); + cutlass::HostTensor B({kK, kN}); + cutlass::HostTensor C({kM, kN}); + cutlass::HostTensor D({kM, kN}); + + uint64_t seed = 2020; + float max = 8; + float min = -8; + + std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape() <<")" << std::endl; + + cutlass::reference::host::TensorFillRandomUniform( + A.host_view(), + seed, + max, + min, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + B.host_view(), + seed + 17, + max, + min, + 0 + ); + +#if 0 // Debug: fill A sequentially and B as Identity matrix for debugging + cutlass::reference::host::BlockFillSequential( + A.host_view().data(), A.host_view().capacity()); + + cutlass::reference::host::TensorFillIdentity(B.host_view()); +#endif + + cutlass::reference::host::TensorFillRandomUniform( + C.host_view(), + seed + 31, + max, + min, + 0 + ); + + A.sync_device(); + B.sync_device(); + C.sync_device(); + D.sync_device(); + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + 
float alpha = 1.0f; + float beta = 0.0f; + + kernel<<< grid, block >>>( + D.device_data(), + alpha, + A.device_data(), + B.device_data(), + beta, + C.device_data() + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Failed to synchronize device after kernel launch." << std::endl; + return -1; + } + + D.sync_host(); + + // Compute reference on host + cutlass::HostTensor D_ref({kM, kN}, false); + cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view()); + + cutlass::reference::host::Gemm< + float, cutlass::layout::RowMajor, + float, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, float> reference_gemm; + + reference_gemm( + {kM, kN, kK}, + alpha, + A.host_ref(), + B.host_ref(), + beta, + D_ref.host_ref(), + float() + ); + + // Verify reference matches computed + if (!cutlass::reference::host::TensorEquals( + D.host_view(), + D_ref.host_view())) { + + std::cerr + << "A =\n" << A.host_view() + << "\n\nB = \n" << B.host_view() + << "\n\nC = " << C.host_view() + << "\n\nRef =\n" << D_ref.host_view() + << "\n\nD =\n" << D.host_view() << "\n\n"; + + std::cerr << "Error - device results mismatch host reference." << std::endl; + + return -1; + } + + std::cout << "Passed" << std::endl; + + return 0; + +} +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/21_quaternion_gemm/CMakeLists.txt b/examples/21_quaternion_gemm/CMakeLists.txt new file mode 100644 index 00000000..1972da14 --- /dev/null +++ b/examples/21_quaternion_gemm/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 21_quaternion_gemm + quaternion_gemm.cu +) + diff --git a/examples/21_quaternion_gemm/quaternion_gemm.cu b/examples/21_quaternion_gemm/quaternion_gemm.cu new file mode 100644 index 00000000..5a402fb1 --- /dev/null +++ b/examples/21_quaternion_gemm/quaternion_gemm.cu @@ -0,0 +1,448 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * + **************************************************************************************************/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + cutlass::gemm::GemmCoord problem_size; + int batch_count; + cutlass::Quaternion alpha; + cutlass::Quaternion beta; + + bool reference_check; + int iterations; + + Options(): + help(false), + problem_size({1024, 1024, 1024}), + batch_count(1), + reference_check(true), + iterations(20), + alpha(1), + beta() { } + + bool valid() { + return true; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch", batch_count); + + 
cmd.get_cmd_line_argument("alpha", alpha.w()); + cmd.get_cmd_line_argument("alpha_i", alpha.x()); + cmd.get_cmd_line_argument("alpha_j", alpha.y()); + cmd.get_cmd_line_argument("alpha_k", alpha.z()); + + cmd.get_cmd_line_argument("beta", beta.w()); + cmd.get_cmd_line_argument("beta_i", beta.x()); + cmd.get_cmd_line_argument("beta_j", beta.y()); + cmd.get_cmd_line_argument("beta_k", beta.z()); + + cmd.get_cmd_line_argument("iterations", iterations); + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "21_quaternion_gemm example\n\n" + << " This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m GEMM M dimension\n" + << " --n GEMM N dimension\n" + << " --k GEMM K dimension\n" + << " --batch Number of GEMM operations executed in one batch\n" + << " --alpha Epilogue scalar alpha (real part)\n" + << " --alpha_i Epilogue scalar alpha_i (imaginary part)\n" + << " --alpha_j Epilogue scalar alpha_j (imaginary part)\n" + << " --alpha_k Epilogue scalar alpha_k (imaginary part)\n" + << " --beta Epilogue scalar beta (real part)\n\n" + << " --beta_i Epilogue scalar beta_i (imaginary part)\n\n" + << " --beta_j Epilogue scalar beta_j (imaginary part)\n\n" + << " --beta_k Epilogue scalar beta_k (imaginary part)\n\n" + << " --iterations Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm --batch=7 --m=1024 --n=512 --k=1024 \\\n" + << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = problem_size.product() * batch_count * 16; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// The code section below describes datatype for input, output matrices and computation between
+// elements in input matrices.
+using precision = float;
+using Element = cutlass::Quaternion<precision>;
+using ElementComputeEpilogue = Element;  // <- data type of epilogue operations
+using ElementAccumulator = Element;      // <- data type of accumulator
+using ElementInputA = Element;           // <- data type of elements in input matrix A
+using ElementInputB = Element;           // <- data type of elements in input matrix B
+using ElementOutput = Element;           // <- data type of elements in output matrix D
+
+// The code section below describes matrix layout of input and output matrices. Row Major for
+// Matrix A, Column Major for Matrix B and Row Major for Matrix C
+using LayoutInputA = cutlass::layout::RowMajor;
+using LayoutInputB = cutlass::layout::ColumnMajor;
+using LayoutOutput = cutlass::layout::RowMajor;
+
+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
+using MMAOp = cutlass::arch::OpClassSimt;
+
+// This code section describes CUDA SM architecture number
+using SmArch = cutlass::arch::Sm50;
+
+// This code section describes the tile size a thread block will compute
+using ShapeMMAThreadBlock =
+    cutlass::gemm::GemmShape<64, 64, 4>;  // <- threadblock tile M = 64, N = 64, K = 4
+// This code section describes tile size a warp will compute
+using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>;  // <- warp tile M = 32, N = 16, K = 4
+// This code section describes the size of MMA op
+using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>;  // <- MMA Op tile M = 1, N = 1, K = 1
+
+// This code section describes how threadblocks are scheduled on GPU
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- Defaults
+
+// This code section describes the epilogue part of the kernel
+using EpilogueOp =
cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + +// Number of pipelines you want to use +constexpr int NumStages = 2; + +using Gemm = cutlass::gemm::device::Gemm; + +int run(Options options) { + + // PASS/FAIL status + bool passed = true; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size = options.problem_size; + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + 4, + -4, + 0); // <- Fill matrix A on host with uniform-distribution random data + + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + 4, + -4, + 0); // <- Fill matrix B on host with uniform-distribution random data + + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + 4, + -4, + 0); // <- Fill matrix C on host with uniform-distribution random data + + 
cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + tensor_a.device_ref(), // <- reference to matrix A on device + tensor_b.device_ref(), // <- reference to matrix B on device + tensor_c.device_ref(), // <- reference to matrix C on device + tensor_d.device_ref(), // <- reference to matrix D on device + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Result structure + Result result; + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = 
cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMMs + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + + // Compute average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + if (options.reference_check) { + + // Create instantiation for device reference gemm kernel + cutlass::reference::device::Gemm gemm_device; + + // Launch device reference gemm kernel + gemm_device(problem_size, + alpha, + tensor_a.device_ref(), + tensor_b.device_ref(), + beta, + tensor_c.device_ref(), + tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + } + + if (passed) { + std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + } + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + return (passed ? 0 : -1); +} + +int main(int argc, char const** argv) { + + Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n", \ + options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); + + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + return run(options); +} + diff --git a/examples/22_quaternion_conv/CMakeLists.txt b/examples/22_quaternion_conv/CMakeLists.txt new file mode 100644 index 00000000..1ff3d6ad --- /dev/null +++ b/examples/22_quaternion_conv/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +cutlass_example_add_executable( + 22_quaternion_conv + quaternion_conv.cu + ) + diff --git a/examples/22_quaternion_conv/quaternion_conv.cu b/examples/22_quaternion_conv/quaternion_conv.cu new file mode 100644 index 00000000..2439eaf8 --- /dev/null +++ b/examples/22_quaternion_conv/quaternion_conv.cu @@ -0,0 +1,660 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +// The code section below describes datatype for input, output tensors and computation between +// elements +using Element = cutlass::Quaternion; +using ElementAccumulator = Element; // Data type of accumulator +using ElementComputeEpilogue = Element; // Data type of epilogue computation (alpha, beta) +using ElementInputA = Element; // Data type of elements in input tensor +using ElementInputB = Element; // Data type of elements in input tensor +using ElementOutput = Element; // Data type of elements in output tensor + +using LayoutInputA = cutlass::layout::TensorNHWC; +using LayoutInputB = cutlass::layout::TensorNHWC; +using LayoutOutput = cutlass::layout::TensorNHWC; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassSimt; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm50; + +// This code section describes the tile size a thread block will compute +using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; // Threadblock tile shape + +// This code section describes tile size a warp will compute +using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; // Warp tile shape 
+ +// This code section describes the size of MMA op +using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; // SIMT instruction shape + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// Number of pipelines you want to use +constexpr int NumStages = 2; + +// This code section describe iterator algorithm selected is Analytic or Optimized +static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; + +// This code section describes the epilogue part of the kernel, we use default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue>; // Data type for alpha/beta in linear combination + + +using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm +>::Kernel; + +using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue 
alpha; + ElementComputeEpilogue beta; + bool benchmark; + std::string tag; + + Options(): + help(false), + input_size(1, 32, 32, 32), + filter_size(32, 3, 3, 32), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + reference_check(false), + measure_performance(true), + iterations(20), + save_workspace(false), + alpha(1), + beta(0), + benchmark(false) { } + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. + // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || + (filter_size.n() % kAlignment)) { + + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || + (padding.w() != filter_size.w() / 2)) { + + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update( + cutlass::Tensor4DCoord input_size, + cutlass::Tensor4DCoord filter_size) { + + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("save-workspace")) { + save_workspace = true; + } + + if (cmd.check_cmd_line_flag("benchmark")) { + benchmark = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", 
input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + filter_size.c() = input_size.c(); + + cmd.get_cmd_line_argument("alpha_w", alpha.w()); + cmd.get_cmd_line_argument("alpha_x", alpha.x()); + cmd.get_cmd_line_argument("alpha_y", alpha.y()); + cmd.get_cmd_line_argument("alpha_z", alpha.z()); + + cmd.get_cmd_line_argument("beta_w", beta.w()); + cmd.get_cmd_line_argument("beta_x", beta.x()); + cmd.get_cmd_line_argument("beta_y", beta.y()); + cmd.get_cmd_line_argument("beta_z", beta.z()); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + if (filter_size.h() == 3 && filter_size.w() == 3) { + padding = {1, 1, 1, 1}; + } + else { + filter_size.h() = 1; + filter_size.w() = 1; + padding = {0, 0, 0, 0}; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "22_quaternion_conv example\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" + << " forward convolution on tensors of layout NHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n Input tensor extent N\n" + << " --h Input tensor extent H\n" + << " --w Input tensor extent W\n" + << " --c Input tensor extent C\n" + << " --k Filter extent K\n" + << " --r Filter extent R\n" + << " --s Filter extent S\n\n" + << " --alpha Epilogue scalar alpha\n" + << " --beta Epilogue scalar beta\n\n" + << " --ref-check If set (true), reference check on the host is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" + << " --iterations Number of profiling iterations to perform.\n" + << " --save-workspace If set, workspace is written to a text file.\n" + << " --tag String to replicate across the first column 
in the results table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" + << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of multiply-adds = NPQK * CRS + int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()) * 16; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result(): + runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) { } + + static std::ostream & print_header(std::ostream &out, Options const &options) { + + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; + + return out; + } + + std::ostream & print(std::ostream &out, int idx, Options const &options) { + + if (!options.tag.empty()) { + out << options.tag << ","; + } + + out + << "conv_" << idx << "," + << options.input_size.n() << "," + << options.input_size.h() << "," + << options.input_size.w() << "," + << options.input_size.c() << "," + << options.filter_size.n() << "," + << options.filter_size.h() << "," + << 
options.filter_size.w() << "," + << runtime_ms << "," + << gflops; + + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Runs one benchmark +Result profile_convolution(Options const &options) { + + Result result; + + // + // Allocate host-device tensors using the CUTLASS Utilities. + // + + cutlass::HostTensor tensor_a(options.input_size); + cutlass::HostTensor tensor_b(options.filter_size); + cutlass::HostTensor tensor_c(options.output_size()); + cutlass::HostTensor tensor_ref_c(options.output_size()); + + // + // Initialize tensors + // + + // Fill tensor A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + 7, + -8, + 0); + + // Fill tensor B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + 7, + -8, + 0); + + // Fill tensor C on host with zeros + cutlass::reference::host::TensorFill( + tensor_c.host_view()); + + // Fill tensor C for reference on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_c.host_view()); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_ref_c.sync_device(); + + // + // Define arguments for CUTLASS Convolution + // + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Construct Conv2dProblemSize with user defined output size + cutlass::conv::Conv2dProblemSize problem_size( + options.input_size, + options.filter_size, + options.padding, + options.conv_stride, + options.dilation, + options.output_size(), + mode, + split_k_slices + ); + + // Construct ImplicitGemm::Argument structure with conv2d + // problem size, data pointers, and epilogue values + typename ImplicitGemm::Arguments arguments{ + problem_size, + 
tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_c.device_ref(), + {options.alpha, options.beta}, + }; + + // + // Initialize CUTLASS Convolution + // + + ImplicitGemm implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + result.status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(result.status); + + result.status = implicit_gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(result.status); + + // + // Launch initialized CUTLASS kernel + // + result.status = implicit_gemm_op(); + + CUTLASS_CHECK(result.status); + + // + // Optional reference check + // + + if (options.reference_check) { + std::cout << "Verification on host...\n"; + + // Compute with reference implementation + cutlass::reference::host::Conv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator, + cutlass::NumericConverter + >( + problem_size, + tensor_a.host_ref(), + tensor_b.host_ref(), + tensor_c.host_ref(), + tensor_ref_c.host_ref(), + options.alpha, + options.beta + ); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + tensor_c.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_c.host_view(), + tensor_ref_c.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } + else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } + else { + result.reference_check = cutlass::Status::kInvalid; + } + + if (options.save_workspace) { + + std::stringstream ss; + + ss << "22_quaternion_conv_" + << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() + << "_" + 
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() + << ".dat"; + + std::ofstream output_workspace(ss.str()); + + output_workspace + << "Input = \n" << tensor_a.host_view() << "\n\n" + << "Filters = \n" << tensor_b.host_view() << "\n\n"; + + if (options.reference_check) { + output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n"; + } + + output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; + + std::cout << "Results written to '" << ss.str() << "'." << std::endl; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. 
+ result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.benchmark) { + // Benchmark several layers + + int batch_sizes[] = {1, 32, 64, 128, 256, 512}; + + struct Benchmark { + int h, w, c, k, r, s; + } layers[] = { + {56, 56, 64, 256, 1, 1}, + {56, 56, 64, 64, 1, 1}, + {56, 56, 64, 64, 3, 3}, + {56, 56, 256, 64, 1, 1}, + {56, 56, 256, 512, 1, 1}, + {56, 56, 256, 128, 1, 1}, + {28, 28, 128, 128, 3, 3}, + {28, 28, 128, 512, 1, 1}, + {28, 28, 512, 128, 1, 1}, + {28, 28, 512, 1024, 1, 1}, + {28, 28, 512, 256, 1, 1}, + {14, 14, 256, 256, 3, 3}, + {14, 14, 256, 1024, 1, 1}, + {14, 14, 1024, 256, 1, 1}, + {14, 14, 1024, 2048, 1, 1}, + {14, 14, 1024, 512, 1, 1}, + {7, 7, 512, 512, 3, 3}, + }; + + Result::print_header(std::cout, options) << std::endl; + + int idx = 1; + + for (auto const &layer : layers) { + for (auto N : batch_sizes) { + + options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); + + Result result = profile_convolution(options); + 
result.print(std::cout, idx, options) << std::endl; + } + + ++idx; + } + } + else { + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + Result result = profile_convolution(options); + + Result::print_header(std::cout, options) << std::endl; + result.print(std::cout, 1, options) << std::endl; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e5bfb78c..79ce3b53 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -28,10 +28,14 @@ add_custom_target(test_examples) function(cutlass_example_add_executable NAME) set(options) - set(oneValueArgs) + set(oneValueArgs DISABLE_TESTS) set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + if (NOT DEFINED __DISABLE_TESTS) + set(__DISABLE_TESTS OFF) + endif() + cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS}) add_dependencies(cutlass_examples ${NAME}) @@ -60,6 +64,7 @@ function(cutlass_example_add_executable NAME) DEPENDEES test_examples ${__DEPENDEES} TEST_COMMAND_OPTIONS ${__TEST_COMMAND_OPTIONS} DISABLE_EXECUTABLE_INSTALL_RULE + DISABLE_TESTS ${__DISABLE_TESTS} ) endfunction() @@ -83,6 +88,11 @@ foreach(EXAMPLE 15_ampere_sparse_tensorop_gemm 16_ampere_tensorop_conv2dfprop 17_fprop_per_channel_bias + 18_ampere_fp64_tensorop_affine2_gemm + 19_tensorop_canonical + 20_simt_canonical + 21_quaternion_gemm + 22_quaternion_conv ) add_subdirectory(${EXAMPLE}) diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 14b5c9d2..05dfa597 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -33,6 +33,26 @@ namespace cutlass { namespace arch { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + +/// Computes laneId within a warp +CUTLASS_DEVICE +int 
LaneId() { + int ret; + asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); + return ret; +} + +/// Computes SM number the thread is running on +CUTLASS_DEVICE +int SmId() { + int ret; + asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); + return ret; +} + +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// struct Sm50 { static int const kMinComputeCapability = 50; diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index 4abaf0d8..145e16c0 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -51,10 +51,20 @@ struct global_load; ///////////////////////////////////////////////////////////////////////////////////////////////// +#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ + (__CUDACC_VER_MAJOR__ > 11)) && \ + defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ + ! (defined(__clang__) && defined(__CUDA__)) + #define CUTLASS_ENABLE_L2_PREFETCH 1 +#else + #define CUTLASS_ENABLE_L2_PREFETCH 0 +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + // The redundant mov PTX instruction is used to enforce the compiler to // initialize data to zero before ld.global -template +template struct global_load { @@ -62,55 +72,61 @@ struct global_load(&D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %9, 0;\n" - " mov.b32 %0, %10;\n" - " mov.b32 %1, %11;\n" - " mov.b32 %2, %12;\n" - " mov.b32 %3, %13;\n" - " mov.b32 %4, %14;\n" - " mov.b32 %5, %15;\n" - " mov.b32 %6, %16;\n" - " mov.b32 %7, %17;\n" - " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" - " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" - "}\n" - : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), - "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) - : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), - "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), - 
"r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %9, 0;\n" + " mov.b32 %0, %10;\n" + " mov.b32 %1, %11;\n" + " mov.b32 %2, %12;\n" + " mov.b32 %3, %13;\n" + " mov.b32 %4, %14;\n" + " mov.b32 %5, %15;\n" + " mov.b32 %6, %16;\n" + " mov.b32 %7, %17;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.L2::128B.v4.u32 {%4, %5, %6, %7}, [%18];\n" +#else + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" +#endif + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), + "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) + : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), + "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), + "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); } }; -template +template struct global_load { CUTLASS_DEVICE global_load(AccessType &D, void const *ptr, bool pred_guard) { uint4 &data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " mov.b32 %0, %6;\n" - " mov.b32 %1, %7;\n" - " mov.b32 %2, %8;\n" - " mov.b32 %3, %9;\n" - " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" - "}\n" - : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) - : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " mov.b32 %0, %6;\n" + " mov.b32 %1, %7;\n" + " mov.b32 %2, %8;\n" + " mov.b32 %3, %9;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n" +#else + " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" +#endif + "}\n" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); } }; -template +template struct 
global_load { @@ -118,21 +134,24 @@ struct global_load(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " mov.b32 %0, %4;\n" - " mov.b32 %1, %5;\n" - " @p ld.global.v2.u32 {%0, %1}, [%2];\n" - "}\n" - : "=r"(data.x), "=r"(data.y) - : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " mov.b32 %0, %4;\n" + " mov.b32 %1, %5;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.v2.u32 {%0, %1}, [%2];\n" +#else + " @p ld.global.v2.u32 {%0, %1}, [%2];\n" +#endif + "}\n" + : "=r"(data.x), "=r"(data.y) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); } }; -template +template struct global_load { @@ -140,20 +159,23 @@ struct global_load(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b32 %0, %3;\n" - " @p ld.global.u32 %0, [%1];\n" - "}\n" - : "=r"(data) - : "l"(ptr), "r"((int)pred_guard), "r"(data)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b32 %0, %3;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.u32 %0, [%1];\n" +#else + " @p ld.global.u32 %0, [%1];\n" +#endif + "}\n" + : "=r"(data) + : "l"(ptr), "r"((int)pred_guard), "r"(data)); } }; -template +template struct global_load { @@ -161,20 +183,23 @@ struct global_load(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b16 %0, %3;\n" - " @p ld.global.u16 %0, [%1];\n" - "}\n" - : "=h"(data) - : "l"(ptr), "r"((int)pred_guard), "h"(data)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b16 %0, %3;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p ld.global.L2::128B.u16 %0, [%1];\n" +#else + " @p ld.global.u16 %0, [%1];\n" +#endif + "}\n" + : "=h"(data) + : "l"(ptr), "r"((int)pred_guard), "h"(data)); } }; -template +template struct global_load { diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index 
1b5bb10b..c93136bd 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -30,6 +30,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/cache_operation.h" @@ -90,7 +91,11 @@ struct cp_async { "{\n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" +#else " @p cp.async.ca.shared.global [%1], [%2], %3;\n" +#endif "}\n" ::"r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); @@ -123,7 +128,11 @@ struct cp_async_zfill { int src_in_bytes = (pred_guard ? SizeInBytes : 0); asm volatile( +#if CUTLASS_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#else "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#endif "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); #else @@ -163,7 +172,11 @@ struct cp_async { "{\n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" +#else " @p cp.async.cg.shared.global [%1], [%2], %3;\n" +#endif "}\n" ::"r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); @@ -195,7 +208,11 @@ struct cp_async_zfill { int src_in_bytes = (pred_guard ? 
SizeInBytes : 0); asm volatile( +#if CUTLASS_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#else "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), +#endif "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); #else diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 1672e607..1367ab73 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -30,6 +30,7 @@ #include "cutlass/array.h" #include "cutlass/numeric_types.h" +#include "cutlass/functional.h" #include "cutlass/gemm/gemm.h" #include "cutlass/arch/arch.h" @@ -130,11 +131,12 @@ template < /// Layout of C matrix (concept: MatrixLayout) typename LayoutC, /// Inner product operator - typename Operator + typename Operator_ > -struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator> { +struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator_> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = Operator_; CUTLASS_HOST_DEVICE void operator()( @@ -144,7 +146,9 @@ struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, El Array const &c ) { - d[0] = a[0] * b[0] + c[0]; + multiply_add op; + + d[0] = op(a[0], b[0], c[0]); } }; diff --git a/include/cutlass/arch/mma_sm50.h b/include/cutlass/arch/mma_sm50.h index fa8e1949..0d47d88b 100644 --- a/include/cutlass/arch/mma_sm50.h +++ b/include/cutlass/arch/mma_sm50.h @@ -30,6 +30,8 @@ #include "cutlass/arch/mma.h" #include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/functional.h" #include "cutlass/layout/matrix.h" #include "cutlass/gemm/gemm.h" @@ -379,5 +381,35 @@ struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Matrix multiply-add operation for Quaternions +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout 
of C matrix + typename LayoutC +> +struct Mma, 1, Quaternion, LayoutA, Quaternion, LayoutB, Quaternion, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using Element = Quaternion; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + multiply_add op; + d[0] = op(a[0], b[0], c[0]); + } + +}; + } } + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index fa6d288a..01d3a06d 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -29,7 +29,7 @@ #pragma once // CUTLASS WMMA does not support clang at present. -#if !defined(__clang__) +#if !(defined(__clang__) && defined(__CUDA__)) #if (__CUDACC_VER_MAJOR__ >= 9) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) @@ -52,7 +52,7 @@ #endif #endif -#endif //!defined(__clang__) +#endif //!(defined(__clang__) && defined(__CUDA__)) #if defined(CUTLASS_ARCH_WMMA_ENABLED) diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 4eee9960..28971db1 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -49,7 +49,7 @@ class Array; template struct sizeof_bits > { static int const value = - sizeof(typename Array::Storage) * 8 * Array::kStorageElements; + int(sizeof(typename Array::Storage)) * 8 * int(Array::kStorageElements); }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index 81008df7..a810b9bd 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -62,7 +62,7 @@ public: using Element = T; /// Number of logical elements per stored object - static int const kElementsPerStoredItem = (sizeof(Storage) * 8) / sizeof_bits::value; + static int const kElementsPerStoredItem = 
int(sizeof(Storage) * 8) / sizeof_bits::value; /// Number of storage elements static size_t const kStorageElements = N / kElementsPerStoredItem; diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index fc32a509..cc1e167c 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -33,6 +33,7 @@ #include #include #include +#include #endif #include "cutlass/cutlass.h" @@ -76,7 +77,13 @@ struct alignas(2) bfloat16_t { asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); #else - uint32_t bits = reinterpret_cast(x); + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(x); + #else + std::memcpy(&bits, &x, sizeof(bits)); + #endif if ((bits & 0x7f800000) != 0x7f800000) { @@ -106,14 +113,28 @@ struct alignas(2) bfloat16_t { CUTLASS_HOST_DEVICE explicit bfloat16_t(int x) { float flt = static_cast(x); - storage = uint16_t(reinterpret_cast(flt) >> 16); + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(flt); + #else + std::memcpy(&bits, &flt, sizeof(bits)); + #endif + + storage = uint16_t(bits >> 16); } /// Converts to float CUTLASS_HOST_DEVICE operator float() const { unsigned bits = (unsigned(storage) << 16); + #if defined(__CUDA_ARCH__) return reinterpret_cast(bits); + #else + float flt; + std::memcpy(&flt, &bits, sizeof(flt)); + return flt; + #endif } /// Converts to float @@ -237,11 +258,22 @@ cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { CUTLASS_HOST_DEVICE bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { - uint16_t a_mag = (reinterpret_cast(a) & 0x7fff); - uint16_t b_sign = (reinterpret_cast(b) & 0x8000); + uint16_t a_bits; + uint16_t b_bits; + + #if defined(__CUDA_ARCH__) + a_bits = reinterpret_cast(a); + b_bits = reinterpret_cast(b); + #else + std::memcpy(&a_bits, &a, sizeof(a_bits)); + std::memcpy(&b_bits, &b, sizeof(b_bits)); + #endif + + uint16_t a_mag = (a_bits & 0x7fff); + uint16_t b_sign = (b_bits & 0x8000); uint16_t result = (a_mag | b_sign); - return 
reinterpret_cast(result); + return bfloat16_t::bitcast(result); } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 3312619c..3374c657 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -38,6 +38,8 @@ #include "cutlass/bfloat16.h" #include "cutlass/tfloat32.h" +#include "cutlass/fast_math.h" + #if !defined(__CUDACC_RTC__) #include #endif @@ -442,16 +444,16 @@ CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { /// Computes the complex exponential of z. template CUTLASS_HOST_DEVICE complex exp(complex const &z) { - return complex(real(z) * cos(imag(z)), real(z) * sin(imag(z))); + return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); } -/// Computes the complex exponential of z. +/// Computes the log of z template CUTLASS_HOST_DEVICE complex log(complex const &z) { return complex(log(abs(z)), arg(z)); } -/// Computes the complex exponential of z. 
+/// Computes the log base 10 of z template CUTLASS_HOST_DEVICE complex log10(complex const &z) { return log(z) / T(log(T(10))); @@ -484,6 +486,9 @@ template struct RealType< complex > { using Type = T; + /// Number of elements + static int const kExtent = 2; + CUTLASS_HOST_DEVICE static complex from_real(double x) { return complex(static_cast(x)); diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index fd87e1ac..4426ece6 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -284,6 +284,27 @@ public: return cutlass::MatrixCoord ({dilation_h, dilation_w}); } + + ///////////////////////////////////////////////////////////////// + // Methods used for strided dgrad implementation + ///////////////////////////////////////////////////////////////// + /// Number of filter r positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_r(int r) const { + return ((R - r + stride_h - 1) / stride_h); + } + + /// Number of filter s positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_s(int s) const { + return ((S - s + stride_w - 1) / stride_w); + } + + /// Number of filter positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_positions(int r, int s) const { + return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -444,6 +465,27 @@ int64_t implicit_gemm_tensor_c_size( //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Strided dgrad helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Returns number of CTAs tile M to cover valid MMAs per starting filter 
position +CUTLASS_HOST_DEVICE +int strided_dgrad_tile_m_per_filter( + Conv2dProblemSize const &problem_size, + int tile_size_m) { + + // Compute NHW rows in Dx output that needs MMA per starting filter position + int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h; + int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w; + int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter; + + // Number of CTAs tile M to cover valid MMAs per starting filter position + int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m; + + return tile_m_per_filter; +} + + } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 95afe94f..f4873d82 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -115,4 +115,3 @@ enum class SplitKMode { } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 5535b09a..ba1572e2 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -71,6 +71,7 @@ public: static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator; static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport; static int const kWarpCount = (ThreadblockShape::kM / WarpShape::kM) * @@ -104,12 +105,37 @@ public: return status; } + // check for unsupported problem sizes for strided dgrad implementation + if (kConvolutionalOperator == conv::Operator::kDgrad && + kStrideSupport == 
conv::StrideSupport::kStrided) { + + // Unity stride (1x1) is supported by strided dgrad but disabled for performance + // reasons. For unity stride, use strided dgrad optimized unity stride specialization. + // Note that unit tests exercise strided dgrad for unity stride to make sure that strided + // dgrad implementation is functionally sound. + // Strided dgrad implementation also supports mixed strides, i.e., (1x2) and (2x1) + if(args.problem_size.stride_h == 1 && args.problem_size.stride_w == 1) { + return Status::kErrorNotSupported; + } + + // split-k (serial or parallel) is not supported for strided dgrad + if(args.problem_size.split_k_slices > 1) { + return Status::kErrorNotSupported; + } + + // dilation > {1x1} is not supported for strided dgrad + if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { + return Status::kErrorNotSupported; + } + } + + // Determine grid shape ThreadblockSwizzle threadblock_swizzle; dim3 grid = threadblock_swizzle.get_grid_shape( threadblock_swizzle.get_tiled_shape( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), + kConvolutionalOperator, + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.problem_size.split_k_slices)); @@ -131,7 +157,8 @@ public: ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), + kConvolutionalOperator, + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.problem_size.split_k_slices); @@ -220,6 +247,7 @@ public: /// Runs the kernel using initialized state. 
Status run(cudaStream_t stream = nullptr) { + ThreadblockSwizzle threadblock_swizzle; dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); diff --git a/include/cutlass/conv/kernel/default_conv2d.h b/include/cutlass/conv/kernel/default_conv2d.h index 603856a4..c804dc56 100644 --- a/include/cutlass/conv/kernel/default_conv2d.h +++ b/include/cutlass/conv/kernel/default_conv2d.h @@ -33,6 +33,7 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/conv/threadblock/threadblock_swizzle.h" #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" @@ -41,6 +42,9 @@ #include "cutlass/conv/threadblock/implicit_gemm_pipelined.h" #include "cutlass/conv/threadblock/implicit_gemm_multistage.h" #include "cutlass/conv/kernel/implicit_gemm_convolution.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h" + + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -62,7 +66,7 @@ struct DefaultConvEpilogue { using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< Shape, WarpMmaTensorOp, - 1, + PartitionsK, OutputOp, OutputOp::kCount >::Epilogue; @@ -85,7 +89,49 @@ struct DefaultConvEpilogue< using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp< Shape, WarpMmaTensorOp, - 1, + PartitionsK, + OutputOp, + OutputOp::kCount + >::Epilogue; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Defaults for strided Dgrad +template < + typename ArchTag, + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename OutputOp +> +struct DefaultConvEpilogueStridedDgrad { + using Epilogue = typename 
epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + OutputOp::kCount + >::Epilogue; +}; + +template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename OutputOp +> +struct DefaultConvEpilogueStridedDgrad< + arch::Sm70, + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp +> { + + using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad< + Shape, + WarpMmaTensorOp, + PartitionsK, OutputOp, OutputOp::kCount >::Epilogue; diff --git a/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/include/cutlass/conv/kernel/default_conv2d_dgrad.h index f81c3897..53395e41 100644 --- a/include/cutlass/conv/kernel/default_conv2d_dgrad.h +++ b/include/cutlass/conv/kernel/default_conv2d_dgrad.h @@ -35,7 +35,7 @@ #include "cutlass/conv/kernel/default_conv2d.h" #include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" -#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" #include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h" #include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h" #include "cutlass/conv/threadblock/conv2d_tile_iterator.h" @@ -83,7 +83,6 @@ template < typename ElementC, typename LayoutC, typename ElementAccumulator, - typename OperatorClass, typename ArchTag, typename ThreadblockShape, typename WarpShape, @@ -101,7 +100,7 @@ struct DefaultConv2dDgrad < ElementC, LayoutC, ElementAccumulator, - OperatorClass, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, @@ -117,7 +116,7 @@ struct DefaultConv2dDgrad < // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - 
ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, Stages, MathOperatorTag>; // Define iterators over tiles from the A operand @@ -138,7 +137,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kStrided >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -160,17 +160,19 @@ struct DefaultConv2dDgrad < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue - using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< Mma, Epilogue, ThreadblockSwizzle, @@ -188,7 +190,6 @@ template < typename ElementC, typename LayoutC, typename ElementAccumulator, - typename OperatorClass, typename ArchTag, typename ThreadblockShape, typename WarpShape, @@ -205,7 +206,7 @@ struct DefaultConv2dDgrad < ElementC, LayoutC, ElementAccumulator, - OperatorClass, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, @@ -221,13 +222,13 @@ struct DefaultConv2dDgrad < // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, MathOperatorTag>; // Define iterators over tiles from the A operand using ThreadMapA = 
typename MmaCore::IteratorThreadMapA; using IteratorA = - cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::TileIteratorStridedDgrad< cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementA, @@ -241,11 +242,12 @@ struct DefaultConv2dDgrad < // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; using IteratorB = - cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::TileIteratorStridedDgrad< cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kStrided > >; @@ -267,17 +269,19 @@ struct DefaultConv2dDgrad < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue - using Epilogue = typename detail::DefaultConvEpilogue< + using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< ArchTag, ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; // Define the kernel - using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< Mma, Epilogue, ThreadblockSwizzle, @@ -297,7 +301,6 @@ template < typename ElementC, typename LayoutC, typename ElementAccumulator, - typename OperatorClass, typename ArchTag, typename ThreadblockShape, typename WarpShape, @@ -315,7 +318,7 @@ struct DefaultConv2dDgrad < ElementC, LayoutC, ElementAccumulator, - OperatorClass, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, @@ -331,7 +334,7 @@ struct DefaultConv2dDgrad < // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + ElementB, layout::RowMajor, 
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, Stages, MathOperatorTag>; // Define iterators over tiles from the A operand @@ -352,7 +355,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kUnity >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -374,11 +378,13 @@ struct DefaultConv2dDgrad < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; @@ -402,7 +408,6 @@ template < typename ElementC, typename LayoutC, typename ElementAccumulator, - typename OperatorClass, typename ArchTag, typename ThreadblockShape, typename WarpShape, @@ -419,7 +424,7 @@ struct DefaultConv2dDgrad < ElementC, LayoutC, ElementAccumulator, - OperatorClass, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, @@ -435,7 +440,7 @@ struct DefaultConv2dDgrad < // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, MathOperatorTag>; // Define iterators over tiles from the A operand @@ -459,7 +464,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kUnity > >; @@ -481,12 +487,14 @@ struct DefaultConv2dDgrad < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename detail::DefaultConvEpilogue< 
ArchTag, ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; @@ -511,7 +519,6 @@ template < typename ElementC, typename LayoutC, typename ElementAccumulator, - typename OperatorClass, typename ArchTag, typename ThreadblockShape, typename WarpShape, @@ -529,7 +536,7 @@ struct DefaultConv2dDgrad < ElementC, LayoutC, ElementAccumulator, - OperatorClass, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, @@ -545,7 +552,7 @@ struct DefaultConv2dDgrad < // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, Stages, MathOperatorTag>; // Define iterators over tiles from the A operand @@ -588,11 +595,13 @@ struct DefaultConv2dDgrad < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; @@ -616,7 +625,6 @@ template < typename ElementC, typename LayoutC, typename ElementAccumulator, - typename OperatorClass, typename ArchTag, typename ThreadblockShape, typename WarpShape, @@ -633,7 +641,7 @@ struct DefaultConv2dDgrad < ElementC, LayoutC, ElementAccumulator, - OperatorClass, + arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, @@ -649,7 +657,7 @@ struct DefaultConv2dDgrad < // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, - ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + ElementB, layout::RowMajor, ElementAccumulator, 
layout::RowMajor, arch::OpClassTensorOp, 2, MathOperatorTag>; // Define iterators over tiles from the A operand @@ -695,12 +703,14 @@ struct DefaultConv2dDgrad < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename detail::DefaultConvEpilogue< ArchTag, ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; @@ -734,8 +744,7 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag -> + typename MathOperatorTag> struct DefaultConv2dDgrad < ElementA, LayoutA, @@ -754,7 +763,7 @@ struct DefaultConv2dDgrad < Stages, MathOperatorTag, IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided + conv::StrideSupport::kUnity > { // Define the core components from GEMM @@ -770,7 +779,7 @@ struct DefaultConv2dDgrad < cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kStrided + conv::StrideSupport::kUnity >; using SmemIteratorA = typename MmaCore::SmemIteratorA; @@ -781,7 +790,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + conv::StrideSupport::kUnity >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -823,6 +833,110 @@ struct DefaultConv2dDgrad < ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + 
WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, /// multi-stage pipeline, and FFMA-based mainloop for SM80 @@ -888,7 +1002,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kUnity >; using SmemIteratorB = typename MmaCore::SmemIteratorB; @@ -928,6 +1043,8 @@ struct DefaultConv2dDgrad < }; + +///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, @@ -966,7 +1083,7 @@ struct DefaultConv2dDgrad < 2, MathOperatorTag, IteratorAlgorithm::kAnalytic, - StrideSupport::kStrided + conv::StrideSupport::kUnity > { // Define the core components from GEMM @@ -983,7 +1100,7 @@ struct DefaultConv2dDgrad < cutlass::MatrixShape, ElementA, ThreadMapA, - StrideSupport::kStrided + conv::StrideSupport::kUnity > >; @@ -996,7 +1113,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + conv::StrideSupport::kUnity > >; @@ -1034,6 +1152,112 @@ struct DefaultConv2dDgrad < conv::Operator::kDgrad >; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv2dDgrad < + ElementA, + 
LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel 
+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1104,7 +1328,8 @@ struct DefaultConv2dDgrad < cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< cutlass::MatrixShape, ElementB, - ThreadMapB + ThreadMapB, + StrideSupport::kUnity > >; @@ -1144,8 +1369,6 @@ struct DefaultConv2dDgrad < }; ///////////////////////////////////////////////////////////////////////////////////////////////// - - } // namespace kernel } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index d22fb7f0..e8d8b844 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -157,11 +157,13 @@ struct DefaultConv2dFprop < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; @@ -271,11 +273,13 @@ struct DefaultConv2dFprop < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, InterleavedK @@ -378,12 +382,14 @@ struct DefaultConv2dFprop < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename detail::DefaultConvEpilogue< ArchTag, ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; @@ -494,11 +500,13 @@ struct DefaultConv2dFprop < 
MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, InterleavedK @@ -602,11 +610,13 @@ struct DefaultConv2dFprop < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; @@ -708,11 +718,13 @@ struct DefaultConv2dFprop < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, InterleavedK @@ -817,12 +829,14 @@ struct DefaultConv2dFprop < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename detail::DefaultConvEpilogue< ArchTag, ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; @@ -923,11 +937,13 @@ struct DefaultConv2dFprop < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, InterleavedK diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h new file mode 100644 index 00000000..13dcb8ad --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -0,0 +1,117 @@ 
+/***************************************************************************************************
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief
+      Defines a GEMM with Broadcast based on an existing UniversalGemm kernel.
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> +struct DefaultConv2dFpropWithBroadcast { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + ElementC, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = 
cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h new file mode 100644 index 00000000..23989682 --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief
+      Defines a GEMM with Reduction based on an existing UniversalGemm kernel.
+
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/conv/kernel/default_conv2d_fprop.h"
+#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h"
+
+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h"
+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace conv {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename OperatorClass,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename EpilogueReductionOp,
+  typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
+  conv::StrideSupport StrideSupport = StrideSupport::kStrided
+>
+struct DefaultConv2dFpropWithReduction {
+
+  using ImplicitGemmBase = typename DefaultConv2dFprop<
+    ElementA,
LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + EpilogueOutputOp, + EpilogueReductionOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/include/cutlass/conv/kernel/default_conv2d_wgrad.h index 1bb68689..69986b20 100644 --- a/include/cutlass/conv/kernel/default_conv2d_wgrad.h +++ b/include/cutlass/conv/kernel/default_conv2d_wgrad.h @@ -160,11 +160,13 @@ struct DefaultConv2dWgrad < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; @@ -266,12 +268,14 @@ struct DefaultConv2dWgrad < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename detail::DefaultConvEpilogue< ArchTag, ThreadblockShape, 
WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; @@ -371,11 +375,13 @@ struct DefaultConv2dWgrad < Stages >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue; @@ -477,12 +483,14 @@ struct DefaultConv2dWgrad < MmaPolicy >; + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + // Define the epilogue using Epilogue = typename detail::DefaultConvEpilogue< ArchTag, ThreadblockShape, WarpMmaTensorOp, - 1, + kPartitionsK, EpilogueOutputOp >::Epilogue; diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index fbc44b15..ae7e024e 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -92,7 +92,8 @@ struct ImplicitGemmConvolution { static int const kStages = Mma::kStages; static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; - + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + /// Warp count (concept: GemmShape) using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; @@ -188,6 +189,8 @@ struct ImplicitGemmConvolution { ConvProblemSize problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + int gemm_k_iterations; typename Mma::IteratorA::Params iterator_A; typename Mma::IteratorA::Element const *ptr_A; @@ -206,7 +209,7 @@ struct ImplicitGemmConvolution { // CUTLASS_HOST_DEVICE - Params(): gemm_k_iterations(0) { } + Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } /// CUTLASS_HOST_DEVICE @@ -236,6 +239,8 @@ struct ImplicitGemmConvolution { implicit_gemm_problem_size, 
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); } }; @@ -260,7 +265,7 @@ struct ImplicitGemmConvolution { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || @@ -327,7 +332,7 @@ struct ImplicitGemmConvolution { // Compute logical position within grid threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // If performing a reduction via split-K, fetch the initial synchronization if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h new file mode 100644 index 00000000..a4a12bee --- /dev/null +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Template for a pipelined Implicit GEMM kernel.
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem
+>
+struct ImplicitGemmConvolutionStridedDgrad {
+
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using EpilogueOutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  static Operator const kConvolutionalOperator = ConvOperator;
+
+  using ElementA = typename Mma::IteratorA::Element;
+  using LayoutA = typename Mma::IteratorA::Layout;
+  using ElementB = typename Mma::IteratorB::Element;
+  using LayoutB = typename Mma::IteratorB::Layout;
+  using ElementC = typename EpilogueOutputOp::ElementOutput;
+
+  /// Set output tensor C layout
+  using LayoutC = LayoutA;
+
+  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
+  using ElementCompute = typename EpilogueOutputOp::ElementCompute;
+
+  using WarpMmaOperator = typename Mma::Policy::Operator;
+
+  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
+  using MathOperator = typename ArchMmaOperator::Operator;
+
+  using OperatorClass = typename WarpMmaOperator::OperatorClass;
+  using ArchTag = typename WarpMmaOperator::ArchTag;
+
+  using ThreadblockShape = typename Mma::Shape;
+  using WarpShape = typename WarpMmaOperator::Shape;
+  using InstructionShape = typename ArchMmaOperator::Shape;
+
+  static int const kStages = Mma::kStages;
+  static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
+  static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;
+
+  /// Warp count (concept: GemmShape)
+  using WarpCount = typename Mma::WarpCount;
+  static int const kThreadCount = 32 * WarpCount::kCount;
+
+  using TensorRefA = typename Mma::IteratorA::TensorRef;
+  using TensorRefB = typename Mma::IteratorB::TensorRef;
+  using TensorRefC = cutlass::TensorRef;
+
+  /// Check iterator A and B convolution dimension are the same and
+  // set device::ImplicitGemmConvolution::kConvDim
+  static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
+    "Convolution on
different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + cutlass::platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // Strided dgrad uses a specialized threadblock swizzle for functionality and performance + static_assert((std::is_same::value) || + (std::is_same>::value) || + (std::is_same>::value) || + (std::is_same>::value), + "Needs ThreadblockSwizzle type specialized for strided dgrad"); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + 
output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + FastDivmod filter_s_divmod; + int gemm_k_iterations; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): gemm_k_iterations(0) { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + filter_s_divmod(args.problem_size.stride_w), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + } + }; + + /// Shared memory storage structure + 
union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionStridedDgrad() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Compute starting filter position for strided dgrad + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size, + ThreadblockShape::kM); + int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter); + + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // int start_r = filter_tile_m / (params.problem_size.stride_w); + // int start_s = filter_tile_m % (params.problem_size.stride_w); + + int start_r, start_s; + params.filter_s_divmod(start_r, start_s, filter_tile_m); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA + if (start_r < params.problem_size.R && start_s < params.problem_size.S) { + // Scale gemm_k_iterations for strided dgrad + int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S) + ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s); + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + start_r, start_s, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + start_r, start_s, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + start_r, start_s, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + start_r, start_s, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
+ if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + __threadfence(); + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Run efficient epilogue + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h new file mode 100644 index 00000000..1e8832ec --- /dev/null +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -0,0 +1,493 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Template for a pipelined Implicit GEMM kernel.
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionWithFusedEpilogue { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on 
different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + cutlass::platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + void * ptr_Vector; + void * ptr_Tensor; + + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial, + void * ptr_Vector = nullptr, + void * ptr_Tensor = nullptr, + typename LayoutC::Stride::Index ldr = 0, + typename LayoutC::Stride::Index ldt = 0 + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + 
output_op(output_op), + split_k_mode(split_k_mode), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + ldr(ldr), + ldt(ldt) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + typename Epilogue::TensorTileIterator::Params params_Tensor; + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + void * ptr_Tensor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + gemm_k_iterations(0), + ptr_Vector(nullptr), + ldr(0), + ptr_Tensor(nullptr) + { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + params_Tensor(args.ldt), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + 
ptr_Tensor(args.ptr_Tensor) + + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionWithFusedEpilogue() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + typename Epilogue::ElementTensor *ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? 
nullptr + : ptr_Tensor, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; + } + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + __threadfence(); + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + ConvOutputIteratorParameter::extent(params.problem_size), + threadblock_offset); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h index 8afb4968..49083a5d 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -55,12 +55,29 @@ namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity +> +class Conv2dDgradFilterTileAccessIteratorAnalytic; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs +// on non-contributing w positions template < typename Shape_, typename Element_, typename ThreadMap_ > -class Conv2dDgradFilterTileAccessIteratorAnalytic { +class Conv2dDgradFilterTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided +> { public: // @@ -90,6 +107,197 @@ public: using Params = Conv2dAnalyticParams; +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous 
dimension + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Moves filter_s + filter_s_ += problem_size_.stride_w; + if (filter_s_ < problem_size_.S) { + return; + } + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r + filter_r_ += problem_size_.stride_h; + if (filter_r_ < problem_size_.R) { + return; + } + // Restore filter_r + filter_r_ = start_r_; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = offset_c_[iteration_contiguous_]; + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % (128/sizeof_bits::value)) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv2dDgradFilterTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity +>{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + private: Params const ¶ms_; diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 937216d5..b2351804 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -62,7 +62,23 @@ template < 
typename ThreadMap_, conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity > -class Conv2dDgradFilterTileAccessIteratorOptimized { +class Conv2dDgradFilterTileAccessIteratorOptimized; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv2dDgradFilterTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity + > { public: // @@ -79,7 +95,7 @@ public: using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; - static StrideSupport const kStrideSupport = StrideSupport_; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index e33e4ccb..f33a6e12 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/coord.h" +#include "cutlass/functional.h" #include "cutlass/predicate_vector.h" #include "cutlass/tensor_ref.h" #include "cutlass/tensor_view.h" @@ -109,7 +110,7 @@ public: // Parameters structure // - using Params = Conv2dAnalyticParams; + using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; private: @@ -122,36 +123,13 @@ private: int filter_k_; int filter_r_; int 
filter_s_; + int start_r_; + int start_s_; int offset_n_[ThreadMap::Iterations::kStrided]; - int offset_w_[ThreadMap::Iterations::kStrided]; - int offset_h_[ThreadMap::Iterations::kStrided]; - -private: + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; - /// Returns the coordinate in the output tensor Dy that is currently pointed to - /// by the iterator but DOES NOT scale by the convolution stride. This is needed - /// to compute predicates in the valid() method. The return value of the public at() - /// method is correctly scaled. - CUTLASS_HOST_DEVICE - TensorCoord unscaled_at_() const { - int n = offset_n_[iteration_strided_]; - int h = offset_h_[iteration_strided_]; - int w = offset_w_[iteration_strided_]; - - int r = filter_r_; - int s = filter_s_; - - if (problem_size_.mode == Mode::kConvolution) { - r = (problem_size_.R - 1 - r); - s = (problem_size_.S - 1 - s); - } - - int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h); - int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w); - - return TensorCoord(n, p, q, filter_k_); - } public: @@ -161,34 +139,68 @@ public: Conv2dProblemSize const &problem_size, Element const *ptr, int thread_idx, + int start_r, int start_s, MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles ): params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), - filter_k_(0), - filter_r_(0), - filter_s_(0) { + filter_k_(0), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + int filter_r = filter_r_; + int filter_s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + filter_r = (problem_size_.R - 1 - filter_r); + filter_s = (problem_size_.S - 1 - filter_s); + } + + // Starting h, w positions for filter 
position in gemm_k=0 + int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h); + int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w); + + + // Effective P and Q for filter position required for remapping NHW rows + int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; + int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; + + CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; - offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W); - int residual = offset_nhw % (problem_size_.H * problem_size_.W); + // (STEP 1) [reorder NHW rows to start with same filter positions] + offset_n_[s] = offset_npq / (P * Q); + int residual = offset_npq % (P * Q); - offset_h_[s] = residual / problem_size_.W; - offset_w_[s] = residual % problem_size_.W; + int p = (residual / Q); + int q = (residual % Q); + + int mapped_h = (start_h + p * problem_size_.stride_h); + int mapped_w = (start_w + q * problem_size_.stride_w); + + // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0 + // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible + // by stride_h and stride_w + offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; + offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; } } CUTLASS_HOST_DEVICE static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { - return Params(problem_size, layout); + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}); } /// Overrides the internal iteration index @@ -206,18 +218,26 @@ 
public: CUTLASS_HOST_DEVICE void advance() { - // move to the next tile - ++filter_s_; + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; if (filter_s_ < problem_size_.S) { return; } - filter_s_ = 0; - ++filter_r_; + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; if (filter_r_ < problem_size_.R) { return; } - filter_r_ = 0; + // Restore filter_r + filter_r_ = start_r_; + + // Move filter_k filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; } @@ -225,14 +245,20 @@ public: /// by the iterator. CUTLASS_HOST_DEVICE TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int conv_sign = (problem_size_.mode == Mode::kConvolution ? 1 : -1); - TensorCoord coord = unscaled_at_(); + p += (conv_sign * (filter_r_ / problem_size_.stride_h)); + q += (conv_sign * (filter_s_ / problem_size_.stride_w)); return TensorCoord( - coord.n(), - coord.h() / problem_size_.stride_h, - coord.w() / problem_size_.stride_w, - coord.c()); + n, + p, + q, + filter_k_); } @@ -240,11 +266,9 @@ public: CUTLASS_HOST_DEVICE bool valid() const { - TensorCoord unscaled_coord = unscaled_at_(); TensorCoord coord = at(); return - !(unscaled_coord.h() % problem_size_.stride_h) && !(unscaled_coord.w() % problem_size_.stride_w) && coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.P && coord.w() >= 0 && coord.w() < problem_size_.Q && diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 078c9e7f..009f5a72 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -32,6 +32,7 @@ backward 
data gradient (Dgrad), and backward weight gradient (Wgrad). */ + #pragma once #include "cutlass/cutlass.h" @@ -62,11 +63,26 @@ template < typename ThreadMap_, conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity > -class Conv2dDgradOutputGradientTileAccessIteratorOptimized { -public: +class Conv2dDgradOutputGradientTileAccessIteratorOptimized; +///////////////////////////////////////////////////////////////////////////////////////////////// - static_assert(StrideSupport_ == conv::StrideSupport::kUnity, - "Only unit-stride dgrad is supported at this time."); +///////////////////////////////////////////////////////////////////////////////////////////////// +// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad +// with problem stride = {1x1} +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv2dDgradOutputGradientTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity +> { +public: // // Types @@ -417,5 +433,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index 573255da..b8795fce 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -99,7 +99,7 @@ public: private: - Conv2dFpropActivationIteratorOptimizedParams const ¶ms_; + Params const ¶ms_; Conv2dProblemSize const &problem_size_; LongIndex iteration_contiguous_; LongIndex iteration_strided_; @@ -118,7 +118,7 @@ public: CUTLASS_HOST_DEVICE 
Conv2dFpropActivationTileAccessIteratorOptimized( - Conv2dFpropActivationIteratorOptimizedParams const ¶ms, + Params const ¶ms, Conv2dProblemSize const &problem_size, Element const *ptr, int thread_idx, diff --git a/include/cutlass/conv/threadblock/conv2d_params.h b/include/cutlass/conv/threadblock/conv2d_params.h index 3c64b1f7..a50755fa 100644 --- a/include/cutlass/conv/threadblock/conv2d_params.h +++ b/include/cutlass/conv/threadblock/conv2d_params.h @@ -77,6 +77,38 @@ struct Conv2dAnalyticParams { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams +struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int tiled_rows_per_filter; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape + ): layout(layout) { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + #if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED CUTLASS_HOST_DEVICE @@ -199,6 +231,32 @@ struct Conv2dFpropActivationIteratorOptimizedParams { // logical offset added to internal channel counter - units are elements, not bytes filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; } + +#if 0 + /// Prints internal state. 
+ CUTLASS_HOST_DEVICE + void print() { + auto stride = layout.stride(); + printf( + "Conv2dFpropActivationIteratorOptimizedParams:\n" + " layout(w: %d, h: %d, n: %d)\n" + " inc_next[%ld, %ld, %ld]\n" + " filter_c_delta(%d) - PQ(%d)\n" + " pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n" + " q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n", + stride[0], stride[1], stride[2], + inc_next[0], inc_next[1], inc_next[2], + filter_c_delta, + PQ, + pq_divmod.divisor, + pq_divmod.multiplier, + pq_divmod.shift_right, + q_divmod.divisor, + q_divmod.multiplier, + q_divmod.shift_right + ); + } +#endif }; /// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized @@ -324,6 +382,23 @@ struct Conv2dFpropFilterIteratorOptimizedParams filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; } + +#if 0 + /// Prints internal state. + CUTLASS_HOST_DEVICE + void print() { + auto stride = layout.stride(); + printf( + "Conv2dFpropFilterIteratorOptimizedParams:\n" + " layout[%d, %d, %d]\n" + " RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n", + stride[0], stride[1], stride[2], + RS, + filter_c_delta, + inc_next_k, inc_next_rs, inc_next_c + ); + } +#endif }; template @@ -382,6 +457,9 @@ struct Conv2dFpropFilterIteratorOptimizedParams +class TileIteratorStridedDgrad { +public: + using TileAccessIterator = TileAccessIterator_; + + using Shape = typename TileAccessIterator::Shape; + using Element = typename TileAccessIterator::Element; + using Layout = typename TileAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = typename TileAccessIterator::ThreadMap; + using AccessType = typename TileAccessIterator::AccessType; + using TensorRef = typename TileAccessIterator::TensorRef; + using Index = typename TileAccessIterator::Index; + using LongIndex = typename TileAccessIterator::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; + 
static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; + using Params = typename TileAccessIterator::Params; + static int const kConvDim = TileAccessIterator::kConvDim; + using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + +private: + + /// Internal state + TileAccessIterator tile_access_iterator_; + +public: + + /// Constructor + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { } + + CUTLASS_HOST_DEVICE + static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { + return TileAccessIterator::getParams(problem_size, layout); + } + + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + tile_access_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad &operator++() { + tile_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad operator++(int) { + TileIteratorStridedDgrad self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[c + s * ThreadMap::Iterations::kContiguous], + tile_access_iterator_.get() + pointer_offset, + tile_access_iterator_.valid() + ); + + ++tile_access_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + tile_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + tile_access_iterator_.advance(); + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // dispatch to iterator implementation + return TileAccessIterator::can_implement(problem_size); + } +}; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index 1e3a5837..79a506b9 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -243,6 +243,7 @@ public: } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock @@ -250,5 +251,3 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - - diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index f138ef59..48940686 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -196,7 +196,7 @@ private: CUTLASS_HOST_DEVICE TensorCoord at_(int offset_npq, int k) const { - // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // The subsequent fast_divmod() operations are equivalent to the following logical computation: // // // int npq = offset_npq; diff --git a/include/cutlass/conv/threadblock/conv3d_params.h b/include/cutlass/conv/threadblock/conv3d_params.h index c95b52d9..1df71f39 100644 --- 
a/include/cutlass/conv/threadblock/conv3d_params.h +++ b/include/cutlass/conv/threadblock/conv3d_params.h @@ -355,6 +355,145 @@ struct Conv3dDgradFilterIteratorOptimizedParams { } }; +/// Parameters object for Conv3d WGRAD OutputGradient iterator +struct Conv3dWgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + using LongIndex = typename Layout::LongIndex; + + Layout layout; + + int NZPQ; // precomputed product of N*Z*P*Q for clearing predicates + int ZPQ; // product of Z*P*Q + unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ + unsigned zpq_shr; // in device code. + + int PQ; // product of P*Q + unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ + unsigned pq_shr; // in device code. + + unsigned q_mul; // precomputed quantities for fast computation of div/% by Q + unsigned q_shr; // in device code. + + LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile + LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile + LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Incremental offsets in units of bytes (number of elements) * element_size_bits / 8 + offset_next_strided = (threadmap_delta.strided() * layout.stride()[0]) + *
element_size_bits / 8; + + offset_next_contiguous = (threadmap_delta.contiguous()) + * element_size_bits / 8; + + inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0]) + * element_size_bits / 8; + + // Precompute several quantities for fast modulo arithmetic. + NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; + ZPQ = problem_size.Z * problem_size.P * problem_size.Q; + find_divisor(zpq_mul, zpq_shr, ZPQ); + + PQ = problem_size.P * problem_size.Q; + find_divisor(pq_mul, pq_shr, PQ); + + find_divisor(q_mul, q_shr, problem_size.Q); + + } +}; + +/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator +struct Conv3dWgradActivationIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int RSC; // product of R*S*C + unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC + unsigned rsc_shr; // in device code. + + int SC; // product of S*C + unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC + unsigned sc_shr; // in device code. + + unsigned c_mul; // precomputed quantities for fast computation of div/% by C + unsigned c_shr; // in device code. + + int ZPQ; // product of Z*P*Q + unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ + unsigned zpq_shr; // in device code. + + int PQ; // product of P*Q + unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ + unsigned pq_shr; // in device code. + + unsigned q_mul; // precomputed quantities for fast computation of div/% by Q + unsigned q_shr; // in device code. 
+ + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dWgradActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Precompute several quantities for fast modulo arithmetic. + RSC = problem_size.R * problem_size.S * problem_size.C; + find_divisor(rsc_mul, rsc_shr, RSC); + + SC = problem_size.S * problem_size.C; + find_divisor(sc_mul, sc_shr, SC); + + find_divisor(c_mul, c_shr, problem_size.C); + + ZPQ = problem_size.Z * problem_size.P * problem_size.Q; + find_divisor(zpq_mul, zpq_shr, ZPQ); + + PQ = problem_size.P * problem_size.Q; + find_divisor(pq_mul, pq_shr, PQ); + + find_divisor(q_mul, q_shr, problem_size.Q); + + } +}; + } // namespace threadblock } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 2835480d..37694adc 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -45,6 +45,7 @@ #include "cutlass/layout/matrix.h" #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,62 +87,28 @@ public: // Parameters structure // - struct Params { - - Layout 
layout; - - int RSC; // product of R*S*C - unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC - unsigned rsc_shr; // in device code. - - int SC; // product of S*C - unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC - unsigned sc_shr; // in device code. - - unsigned c_mul; // precomputed quantities for fast computation of div/% by C - unsigned c_shr; // in device code. - - int ZPQ; // product of Z*P*Q - unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ - unsigned zpq_shr; // in device code. - - int PQ; // product of P*Q - unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ - unsigned pq_shr; // in device code. - - unsigned q_mul; // precomputed quantities for fast computation of div/% by Q - unsigned q_shr; // in device code. - + struct Params : Conv3dWgradActivationIteratorOptimizedParams { // // Methods // CUTLASS_HOST_DEVICE - Params() { } + Params() {} CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { + Params(Conv3dWgradActivationIteratorOptimizedParams const &base) + : Conv3dWgradActivationIteratorOptimizedParams(base) {} - // Precompute several quantities for fast modulo arithmetic. 
- RSC = problem_size.R * problem_size.S * problem_size.C; - find_divisor(rsc_mul, rsc_shr, RSC); - - SC = problem_size.S * problem_size.C; - find_divisor(sc_mul, sc_shr, SC); - - find_divisor(c_mul, c_shr, problem_size.C); - - ZPQ = problem_size.Z * problem_size.P * problem_size.Q; - find_divisor(zpq_mul, zpq_shr, ZPQ); - - PQ = problem_size.P * problem_size.Q; - find_divisor(pq_mul, pq_shr, PQ); - - find_divisor(q_mul, q_shr, problem_size.Q); - - } + CUTLASS_HOST_DEVICE + Params(Conv3dProblemSize const &problem_size, Layout const &layout) + : Conv3dWgradActivationIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} }; private: diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index d3b356e0..13a60bf3 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -45,6 +45,7 @@ #include "cutlass/layout/matrix.h" #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,61 +87,29 @@ public: // Parameters structure // - struct Params { - - Layout layout; - - int NZPQ; // precomputd product of N*Z*P*Q for clearing predicates - int ZPQ; // product of Z*P*Q - unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ - unsigned zpq_shr; // in device code. 
- - int PQ; // product of P*Q - unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ - unsigned pq_shr; // in device code. - - unsigned q_mul; // precomputed quantities for fast computation of div/% by Q - unsigned q_shr; // in device code. - - LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile - LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile - LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile - + struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams { // // Methods // + CUTLASS_HOST_DEVICE + Params() {} CUTLASS_HOST_DEVICE - Params() { } + Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base) + : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {} CUTLASS_HOST_DEVICE - Params( - Conv3dProblemSize const &problem_size, - Layout const &layout - ): layout(layout) { - - // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 - offset_next_strided = (ThreadMap::Delta::kStrided * layout.stride()[0]) - * sizeof_bits::value / 8; - - offset_next_contiguous = (ThreadMap::Delta::kContiguous) - * sizeof_bits::value / 8; - - inc_next_nzpq = (Shape::kColumn * problem_size.split_k_slices * layout.stride()[0]) - * sizeof_bits::value / 8; - - // Precompute several quantities for fast modulo arithmetic. 
- NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; - ZPQ = problem_size.Z * problem_size.P * problem_size.Q; - find_divisor(zpq_mul, zpq_shr, ZPQ); - - PQ = problem_size.P * problem_size.Q; - find_divisor(pq_mul, pq_shr, PQ); - - find_divisor(q_mul, q_shr, problem_size.Q); - - } - }; + Params(Conv3dProblemSize const &problem_size, Layout const &layout) + : Conv3dWgradOutputGradientIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} + }; private: diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index aefdcd6d..85890b96 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -377,7 +377,7 @@ public: this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); diff --git a/include/cutlass/conv/threadblock/threadblock_swizzle.h b/include/cutlass/conv/threadblock/threadblock_swizzle.h new file mode 100644 index 00000000..6493ce8a --- /dev/null +++ b/include/cutlass/conv/threadblock/threadblock_swizzle.h @@ -0,0 +1,166 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements several possible threadblock-swizzling functions mapping blockIdx to + Convolution problems. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_HOST_DEVICE +static int get_strided_dgrad_tile_m( + cutlass::conv::Conv2dProblemSize const &problem_size, + int tile_size_m) { + + // CTAs in M dimension per starting filter position + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m); + + // Inflate number of CTAs in M dimension to cover every starting filter position even those that + // may fall out of valid MMA (Dy * w) but are needed to apply epilogue (beta * Dx_source) + // and point-wise fusion + int tile_m = tile_m_per_filter * int(problem_size.stride().product()); + + // There is a possible performance optimization here that leads up to 2x speeds than the current + // CUTLASS strided dgrad performance for stride > filter, i.e., stride={2x2} and filter={1x1}) + // + // * Optimization * + // Only launch CTAs in M dimension which contribute to a row in Dx output + // + // + // * Constraints * + // (A) stride <= filter, for example, stride={2x2} and filter={3x3}: + // - (A.1): There are no constraints for this case and the optimization does + // affect this case functionality or performance. + // (B) stride > filter, for example, stride={2x2} and filter={1x1}: + // - (B.1): Dx output tensor should be zero initialized + // - (B.2): The kernel epilogue cannot apply beta.
Thus, beta should be zero + + return tile_m; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for strided dgrad convolution +struct StridedDgradHorizontalThreadblockSwizzle : + public gemm::threadblock::GemmHorizontalThreadblockSwizzle { + + using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle; + + CUTLASS_HOST_DEVICE + StridedDgradHorizontalThreadblockSwizzle() { } + + /// Returns the shape of the problem in units of logical tiles + /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + gemm::GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) const { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + // compute number of tiles in m dimension + int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); + + // compute number of tiles in n dimension + int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); + + return gemm::GemmCoord( + tile_m, + tile_n, + split_k_slices); + } + + /// Returns the shape of the problem in units of logical tiles + /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) + private: + using Base::get_tiled_shape; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for strided dgrad convolution +template +struct StridedDgradIdentityThreadblockSwizzle : + public gemm::threadblock::GemmIdentityThreadblockSwizzle { + + using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle; + + CUTLASS_HOST_DEVICE + StridedDgradIdentityThreadblockSwizzle() { } + 
+ /// Returns the shape of the problem in units of logical tiles + /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + gemm::GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) const { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + // compute number of tiles in m dimension + int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); + + // compute number of tiles in n dimension + int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); + + return gemm::GemmCoord( + tile_m, + tile_n, + split_k_slices); + } + + + /// Returns the shape of the problem in units of logical tiles + /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) + private: + using Base::get_tiled_shape; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +} // namespace threadblock +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 7c7aaf3a..c4db95b6 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -412,39 +412,61 @@ Coord operator/(Coord coord, Index s) { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to make a 2-element coordinate +template CUTLASS_HOST_DEVICE -Coord<1> make_Coord(int _0) { - int values[1] = {_0}; - return Coord<1>(values); +Coord<1, T> make_Coord(T _0) { + T values[1] = {_0}; + return Coord<1, T>(values); } /// Helper to make a 2-element coordinate +template CUTLASS_HOST_DEVICE -Coord<2> make_Coord(int _0, int _1) { - int values[2] = {_0, _1}; - return Coord<2>(values); +Coord<2, T> make_Coord(T _0, T _1) { + T values[2] = {_0, _1}; + return Coord<2, T>(values); }
/// Helper to make a 3-element coordinate +template CUTLASS_HOST_DEVICE -Coord<3> make_Coord(int _0, int _1, int _2) { - int values[3] = {_0, _1, _2}; - return Coord<3>(values); +Coord<3, T> make_Coord(T _0, T _1, T _2) { + T values[3] = {_0, _1, _2}; + return Coord<3, T>(values); } /// Helper to make a 4-element coordinate +template CUTLASS_HOST_DEVICE -Coord<4> make_Coord(int _0, int _1, int _2, int _3) { - int values[4] = {_0, _1, _2, _3}; - return Coord<4>(values); +Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) { + T values[4] = {_0, _1, _2, _3}; + return Coord<4, T>(values); } /// Helper to make a 5-element coordinate +template CUTLASS_HOST_DEVICE -Coord<5> make_Coord(int _0, int _1, int _2, int _3, int _4) { - int values[5] = {_0, _1, _2, _3, _4}; - return Coord<5>(values); +Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) { + T values[5] = {_0, _1, _2, _3, _4}; + return Coord<5, T>(values); } + +/// Helper to make a 1-element coordinate +template +CUTLASS_HOST_DEVICE +Coordmake_Coord_with_padding(T _0) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = N - 1; i > 0; --i) { + coord[i] = 0; + } + + coord[0] = _0; + + return coord; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index b25806a3..86383a69 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -34,6 +34,8 @@ #include "cutlass/array.h" #include "cutlass/coord.h" #include "cutlass/numeric_types.h" +#include "cutlass/matrix.h" +#include "cutlass/quaternion.h" #include "cutlass/matrix_shape.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/tensor_view.h" @@ -150,6 +152,45 @@ std::ostream & operator<<(std::ostream &out, MatrixShape const &mat return out; } + +/// Prints matrix to ostream +template +std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { + + for (int i = 0; i < Rows; ++i) { + for (int j = 
0; j < Columns; ++j) { + ScalarIO element(rhs.at(i, j)); + out << (j ? ", " : "") << element; + } + out << "\\n"; + } + + return out; +} + +template +std::ostream &operator<<(std::ostream &out, Quaternion const &rhs) { + + out << ScalarIO(rhs.w()) << " "; + if (rhs.x() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.x()) << "*i "; + if (rhs.y() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.y()) << "*j "; + if (rhs.z() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.z()) << "*k"; + + return out; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // stream operators for cutlass::gemm namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 5a703980..1c34fc61 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -141,26 +141,6 @@ static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2; static const int NUM_THREADS_PER_QUAD = 4; static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2; -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - -/// Computes laneId within a warp -CUTLASS_DEVICE -int LaneId() { - int ret; - asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); - return ret; -} - -/// Computes SM number the thread is running on -CUTLASS_DEVICE -int SmId() { - int ret; - asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); - return ret; -} - -#endif - //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index bcfed6ca..e0da4eed 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -180,6 +180,7 @@ struct GELU > { // GELU operator implemented using the Taylor series approximation template struct 
GELU_taylor { + static const bool kIsHeavy=true; CUTLASS_HOST_DEVICE T operator()(T const &z) const { @@ -193,6 +194,7 @@ struct GELU_taylor { template struct GELU_taylor > { + static const bool kIsHeavy=true; CUTLASS_HOST_DEVICE Array operator()(Array const &rhs) const { Array y; @@ -250,4 +252,3 @@ struct dGELU > { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/epilogue/thread/conversion_op.h b/include/cutlass/epilogue/thread/conversion_op.h index 7cdf6cb0..c604d45a 100644 --- a/include/cutlass/epilogue/thread/conversion_op.h +++ b/include/cutlass/epilogue/thread/conversion_op.h @@ -65,6 +65,8 @@ public: static FloatRoundStyle const kRound = Round; + static bool const kIsHeavy = false; + /// Host-constructable parameters structure struct Params { diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index 0be9fc47..70aac237 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -49,7 +49,9 @@ namespace thread { /// template < typename ElementOutput_, ///< Data type used to load and store tensors - int Count, ///< Number of elements computed per operation + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling @@ -146,6 +148,8 @@ public: if (Scale == ScaleType::OnlyAlphaScaling) return false; + if (Scale == ScaleType::Nothing) return false; + return beta_ != ElementCompute(0); } @@ -167,11 +171,17 @@ public: NumericArrayConverter source_converter; NumericArrayConverter accumulator_converter; - ComputeFragment converted_source = source_converter(source); + // Convert to destination numeric type + NumericArrayConverter destination_converter; + ComputeFragment converted_accumulator = accumulator_converter(accumulator); - // Perform binary operations + if (Scale == ScaleType::Nothing) + return destination_converter(converted_accumulator); + ComputeFragment converted_source = source_converter(source); + + // Perform binary operations ComputeFragment intermediate; multiplies mul_add_source; @@ -180,13 +190,10 @@ public: if (Scale == ScaleType::NoBetaScaling) intermediate = converted_source; else - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X - // Convert to destination numeric type - NumericArrayConverter destination_converter; - return destination_converter(intermediate); } @@ -198,17 +205,20 @@ public: // Convert source to interal compute numeric type NumericArrayConverter accumulator_converter; + // Convert to destination numeric type + NumericArrayConverter destination_converter; + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + if (Scale == ScaleType::Nothing) + 
return destination_converter(converted_accumulator); + // Perform binary operations ComputeFragment intermediate; multiplies mul_accumulator; intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum - // Convert to destination numeric type - NumericArrayConverter destination_converter; - return destination_converter(intermediate); } }; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h new file mode 100644 index 00000000..4e17fc4e --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template < + typename ElementC_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementZ_, + typename ElementT_, + int ElementsPerAccess, + typename ElementwiseOp_ = Identity, + typename BinaryOp_ = plus +> +class LinearCombinationBiasElementwise { +public: + + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using ElementwiseOp = ElementwiseOp_; + using BinaryOp =
BinaryOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = true; + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha + ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + bool skip_elementwise_; + +public: + + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + LinearCombinationBiasElementwise(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); + skip_elementwise_ = false; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + skip_elementwise_ = true; + } + } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentC const &frag_C, + FragmentCompute const &V) const { + + ElementwiseOp elementwise_op; + BinaryOp binary_op; + + FragmentCompute tmp_Accum = NumericArrayConverter()(AB); + FragmentCompute tmp_C = NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]); + result_Z[i] = z; + result_T[i] = skip_elementwise_ ? z : elementwise_op(z); + } + + NumericArrayConverter convert_z; + frag_Z = convert_z(result_Z); + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentCompute const &V) const { + + ElementwiseOp elementwise_op; + BinaryOp binary_op; + + FragmentCompute tmp_Accum = NumericArrayConverter()(AB); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); + result_Z[i] = z; + result_T[i] = skip_elementwise_ ? 
z : elementwise_op(z); + } + + NumericArrayConverter convert_z; + frag_Z = convert_z(result_Z); + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h index 8c898f90..b4145917 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h @@ -28,6 +28,8 @@ #pragma once +#include + #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" @@ -41,6 +43,146 @@ namespace cutlass { namespace epilogue { namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayMaximum { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &lhs, + Array const &rhs) const { + + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = fmax(lhs[i], rhs[i]); + } + + return result; + } +}; + +template +struct ArrayMaximum { + + CUTLASS_DEVICE + Array operator()( + Array const &lhs, + Array const &rhs) const { + + Array result; + + #if __CUDA_ARCH__ >= 800 + int const kVectorCount = ElementsPerAccess / 2; + + + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data()); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(rhs.raw_data()); + __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + #else + __half const *lhs_ptr = 
reinterpret_cast<__half const *>(lhs.raw_data()); + __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data()); + __half *res_ptr = reinterpret_cast<__half *>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]); + } + + #endif + + return result; + } + + CUTLASS_DEVICE + Array operator()( + Array const &lhs, + half_t const &rhs) const { + + Array result; + + #if __CUDA_ARCH__ >= 800 + int const kVectorCount = ElementsPerAccess / 2; + + + __half rhs_raw = reinterpret_cast<__half const &>(rhs); + __half2 rhs_pair = __half2half2(rhs_raw); + + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data()); + __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + } + + #else + + __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); + __half const rhs_raw = reinterpret_cast<__half const &>(rhs); + __half *res_ptr = reinterpret_cast<__half *>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? 
rhs_raw : lhs_ptr[i]); + } + + #endif + + return result; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ReluConditional { + + CUTLASS_HOST_DEVICE + void operator()( + bool conditional[], + Array const &fragment, + Element threshold) const { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + conditional[i] = !(fragment[i] < threshold); + } + } +}; + +template +struct ReluConditional { + + CUTLASS_DEVICE + void operator()( + bool conditional[], + Array const &fragment, + half_t threshold) const { + + __half y = reinterpret_cast<__half const &>(threshold); + __half const *x = reinterpret_cast<__half const *>(fragment.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + conditional[i] = !__hlt(x[i], y); + } + } +}; + +} // namespace detail + ///////////////////////////////////////////////////////////////////////////////////////////////// /// This is a partial specialization for fused Bias and ReLU. 
It supports the option of packing @@ -94,8 +236,11 @@ public: ElementCompute beta; ///< scales source tensor ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory - ElementCompute threshold; ///< ReLu threshold + ElementZ threshold; ///< ReLu threshold + // + // Methods + // // // Methods // @@ -112,16 +257,19 @@ public: Params( ElementCompute alpha, ElementCompute beta, - ElementCompute threshold = ElementCompute() + ElementCompute threshold_ = ElementCompute() ): - alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(threshold) { + alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + NumericConverter convert_threshold; + + threshold = convert_threshold(threshold_); } CUTLASS_HOST_DEVICE Params( ElementCompute alpha - ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(threshold) { + ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(ElementZ()) { } @@ -129,17 +277,20 @@ public: Params( ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, - ElementCompute threshold = ElementCompute() - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold(threshold) { + ElementCompute threshold_ = ElementCompute() + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + NumericConverter convert_threshold; + + threshold = convert_threshold(threshold_); } CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(threshold) { - + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(ElementZ()) { } + }; private: @@ -150,7 +301,7 @@ private: ElementCompute alpha_; ElementCompute beta_; - ElementCompute threshold_; + ElementZ threshold_; public: @@ -179,6 +330,12 @@ public: if (k_partition) { beta_ = ElementCompute(1); } 
+ + if (k_partition != k_partition_count - 1) { + // set to NaN to make ReLU no-op for all except last k partitions + int64_t allones = -1; + threshold_ = reinterpret_cast(allones); + } } /// Applies the operation when is_source_needed() is true @@ -201,18 +358,27 @@ public: CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { - ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]); - bool condition = !(z < threshold_); - z = fmax(z, threshold_); + ElementCompute z = alpha_ * tmp_Accum[i]; + z += beta_ * tmp_C[i]; + z = binary_op(z, V[i]); result_Z[i] = z; - conditions[i] = condition; } NumericArrayConverter convert_z; frag_Z = convert_z(result_Z); + // + // Compute condition + // + + detail::ReluConditional relu_conditional; + relu_conditional(conditions, frag_Z, threshold_); + + detail::ArrayMaximum maximum_op; + frag_Z = maximum_op(frag_Z, threshold_); + if (kStoreT) { PackPredicates pack_predicates; frag_T = pack_predicates(conditions); @@ -238,17 +404,29 @@ public: CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); - - bool condition = !(z < threshold_); - z = fmax(z, threshold_); - result_Z[i] = z; - conditions[i] = condition; } NumericArrayConverter convert_z; frag_Z = convert_z(result_Z); + // + // Compute condition + // + + detail::ReluConditional relu_conditional; + relu_conditional(conditions, frag_Z, threshold_); + + detail::ArrayMaximum maximum_op; + frag_Z = maximum_op(frag_Z, threshold_); + + // + // Compute conditions + // + + // + // Store + // if (kStoreT) { PackPredicates pack_predicates; frag_T = pack_predicates(conditions); diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index e1bf10bb..2cd150e0 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -43,6 +43,17 
@@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +/// Single source of truth for whether to unroll for `LinearCombinationClamp()` +constexpr bool LinearCombinationClampIsHeavy() { + return false; +} + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Applies a linear combination operator to an array of elements then clamps the output before /// converting to the output element type. /// @@ -51,6 +62,8 @@ namespace thread { template < typename ElementOutput_, ///< Data type used to load and store tensors int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling @@ -71,6 +84,8 @@ public: static FloatRoundStyle const kRound = Round; + static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy(); + /// Host-constructable parameters structure struct Params { @@ -282,6 +297,8 @@ public: static FloatRoundStyle const kRound = Round; + static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy(); + /// Host-constructable parameters structure struct Params { @@ -396,10 +413,9 @@ public: // Convert floats back to INT FragmentAccumulator scaled_accumulator; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - scaled_accumulator[i] = __float2int_rn(intermediate[i]); - } + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -427,10 +443,9 @@ public: // Convert floats back to INT FragmentAccumulator 
scaled_accumulator; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - scaled_accumulator[i] = __float2int_rn(intermediate[i]); - } + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); // Convert to destination numeric type NumericArrayConverter destination_converter; @@ -487,6 +502,8 @@ class FastLinearCombinationClamp { static FloatRoundStyle const kRound = Round; + static bool const kIsHeavy = false; + /// Host-constructable parameters structure struct Params { /// scales accumulators diff --git a/include/cutlass/epilogue/thread/linear_combination_dgelu.h b/include/cutlass/epilogue/thread/linear_combination_dgelu.h new file mode 100644 index 00000000..faebcc68 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_dgelu.h @@ -0,0 +1,244 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Functor performing linear combination followed by dGelu operation +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/constants.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements.
+/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementCompute_, ///< Data type returned by this functor + typename ElementAccumulator_, ///< Data type of accumulators + typename ElementSource_, ///< Data type of source tensor + typename ElementTensor_, ///< Data type of additional tensor + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationDGelu { +public: + + using ElementOutput = ElementSource_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ElementSource = ElementSource_; + using ElementTensor = ElementTensor_; + + static bool const kIsHeavy = true; + + static int const kCount = Count; + + using FragmentCompute = Array; + using FragmentAccumulator = Array; + using FragmentSource = Array; + using FragmentTensor = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute threshold; ///< minimum value that is output + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + threshold(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta, + ElementCompute threshold = ElementCompute(0) + ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const 
*alpha_ptr, + ElementCompute const *beta_ptr, + ElementCompute threshold = ElementCompute(0) + ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + ElementCompute threshold_; + bool participates_in_reduction_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationDGelu(Params const &params) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + threshold_ = params.threshold; + participates_in_reduction_ = true; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Returns true if the threadblock computes the reduction + CUTLASS_HOST_DEVICE + bool participates_in_reduction() const { + return participates_in_reduction_; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + // set to NaN to make ReLU no-op for all except last k partitions + int64_t allones = -1; + threshold_ = reinterpret_cast(allones); + // Avoid computing the reduction if this isn't the final Split-K slice + participates_in_reduction_ = false; + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentSource const &source, + FragmentTensor const &tensor) const { + + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute
converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + + dGELU gelu_op; + + // dGelu + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i])); + } + + return intermediate; + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentTensor const &tensor) const { + + // Convert source to internal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + dGELU gelu_op; + + // dGelu with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i])); + } + + return intermediate; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_drelu.h b/include/cutlass/epilogue/thread/linear_combination_drelu.h new file mode 100644 index 00000000..a69d5363 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_drelu.h @@ -0,0 +1,446 @@
+/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues.
+*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementCompute_, ///< Data type returned by this functor + typename ElementAccumulator_, ///< Data type of accumulators + typename ElementSource_, ///< Data type of source tensor + typename ElementTensor_, ///< Data type of additional tensor + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationDRelu { +public: + + using ElementOutput = ElementSource_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ElementSource = ElementSource_; + using ElementTensor = ElementTensor_; + + static int const kCount = Count; + + using FragmentCompute = Array; + using FragmentAccumulator = Array; + using FragmentSource = Array; + using FragmentTensor = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute threshold; ///< minimum value that is output + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; 
///< pointer to source scalar - if not null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + threshold(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta, + ElementCompute threshold = ElementCompute(0) + ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr, + ElementCompute threshold = ElementCompute(0) + ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + ElementTensor threshold_; + bool participates_in_reduction_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationDRelu(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); + threshold_ = ElementTensor(params.threshold); + participates_in_reduction_ = true; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Returns true if the threadblock computes the reduction + CUTLASS_HOST_DEVICE + bool participates_in_reduction() const { + return participates_in_reduction_; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + // set to NaN to make ReLU no-op for all except last k partitions + int64_t allones = -1; + threshold_ = reinterpret_cast(allones); + participates_in_reduction_ = false; + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentSource const &source, + FragmentTensor const &tensor) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + + // dReLU = (cond ? 
dy : 0) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + ElementTensor cond = tensor[i]; + if (cond <= threshold_) { + intermediate[i] = ElementCompute(); + } + } + + return intermediate; + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentTensor const &tensor) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + // dReLU = (cond ? dy : 0) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + ElementTensor cond = tensor[i]; + if (cond <= threshold_) { + intermediate[i] = ElementCompute(); + } + } + + return intermediate; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. 
+/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementCompute_, ///< Data type returned by this functor + typename ElementAccumulator_, ///< Data type of accumulators + typename ElementSource_, ///< Data type of source tensor + int Count, ///< Number of elements computed per operation + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationDReluConditionalBits { +public: + + using ElementOutput = ElementSource_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ElementSource = ElementSource_; + using ElementTensor = uint1b_t; + + static bool const kIsHeavy = false; + + static int const kCount = Count; + + using FragmentCompute = Array; + using FragmentAccumulator = Array; + using FragmentSource = Array; + using FragmentTensor = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + FragmentTensor predicate_mask_; + bool participates_in_reduction_; + +public: + + /// Constructs the function object, 
possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationDReluConditionalBits(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + participates_in_reduction_ = true; + predicate_mask_.clear(); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Returns true if the threadblock computes the reduction + CUTLASS_HOST_DEVICE + bool participates_in_reduction() const { + return participates_in_reduction_; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + predicate_mask_.clear(); + + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + // Avoid computing the reduction if this isn't the final Split-K slice + participates_in_reduction_ = false; + + bit_not not_op; + predicate_mask_ = not_op(predicate_mask_); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentSource const &source, + FragmentTensor const &tensor) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + + bit_or or_op; + + FragmentTensor predicates = 
or_op(tensor, predicate_mask_); + + // Obtain from packed bits + bool conditions[kCount]; + UnpackPredicates unpack_predicates; + + unpack_predicates(conditions, predicates); + + // dReLU = (cond ? dy : 0) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + if (!conditions[i]) { + intermediate[i] = ElementCompute(); + } + } + + return intermediate; + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentTensor const &tensor) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + bit_or or_op; + + FragmentTensor predicates = or_op(tensor, predicate_mask_); + + // Obtain from packed bits + bool conditions[kCount]; + UnpackPredicates unpack_predicates; + + unpack_predicates(conditions, predicates); + + // dReLU = (cond ? 
dy : 0) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + if (!conditions[i]) { + intermediate[i] = ElementCompute(); + } + } + + return intermediate; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h index ebb08056..e7905719 100644 --- a/include/cutlass/epilogue/thread/linear_combination_gelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h @@ -51,6 +51,8 @@ namespace thread { template < typename ElementOutput_, ///< Data type used to load and store tensors int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest @@ -62,6 +64,8 @@ public: using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + static bool const kIsHeavy = true; + static int const kCount = Count; using FragmentOutput = Array; @@ -134,10 +138,11 @@ public: /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE void set_k_partition(int k_partition, int k_partition_count) { - CUTLASS_UNUSED(k_partition_count); if (k_partition) { beta_ = ElementCompute(1); } + + CUTLASS_UNUSED(k_partition_count); } /// Computes: D = gelu( alpha * accumulator + beta * source ) diff --git a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h index 
8ecaab65..580ebc13 100644 --- a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h +++ b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h @@ -52,6 +52,8 @@ namespace thread { template < typename ElementOutput_, ///< Data type used to load and store tensors int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 2f40cf18..a7674a36 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -45,6 +45,17 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +/// Single source of truth for whether to unroll for `LinearCombinationClamp()` +constexpr bool LinearCombinationReluIsHeavy() { + return false; +} + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Applies a linear combination operator to an array of elements. 
/// /// D = alpha * accumulator + beta * source + uniform @@ -52,6 +63,8 @@ namespace thread { template < typename ElementOutput_, ///< Data type used to load and store tensors int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling @@ -72,6 +85,8 @@ public: static FloatRoundStyle const kRound = Round; + static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy(); + /// Host-constructable parameters structure struct Params { @@ -244,6 +259,8 @@ public: using ElementAccumulator = int; using ElementCompute = float; + static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy(); + static int const kCount = Count; using FragmentOutput = Array; @@ -357,10 +374,10 @@ public: ReLu relu; if (Scale == ScaleType::NoBetaScaling) - intermediate = converted_source; + intermediate = converted_source; else - intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform - + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X // Compute threshold optionally @@ -378,10 +395,9 @@ public: // Convert floats back to INT FragmentAccumulator scaled_accumulator; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - scaled_accumulator[i] = __float2int_rn(intermediate[i]); - } + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); // Convert to destination numeric type NumericArrayConverter @@ -416,14 +432,6 @@ public: // Compute threshold optionally intermediate = relu(threshold_, intermediate); - // Convert 
floats back to INT - FragmentAccumulator scaled_accumulator; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - scaled_accumulator[i] = __float2int_rn(intermediate[i]); - } - if (platform::is_same::value || platform::is_same::value || platform::is_same::value || @@ -436,10 +444,9 @@ public: // Convert floats back to INT FragmentAccumulator scaled_accumulator; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - scaled_accumulator[i] = __float2int_rn(intermediate[i]); - } + NumericArrayConverter compute_converter; + + scaled_accumulator = compute_converter(intermediate); // Convert to destination numeric type NumericArrayConverter diff --git a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h index cea2d7a8..70e99a33 100644 --- a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h +++ b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h @@ -51,6 +51,8 @@ namespace thread { template < typename ElementOutput_, ///< Data type used to load and store tensors int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest diff --git a/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h new file mode 100644 index 00000000..ef9c24cc --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + + \brief Functor performing linear combination with elementwise +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/constants.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementCompute_, ///< Data type returned by this functor + typename ElementAccumulator_, ///< Data type of accumulators + typename ElementSource_, ///< Data type of source tensor + typename ElementTensor_, ///< Data type of additional tensor + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationWithElementwise { +public: + + using ElementOutput = ElementSource_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ElementSource = ElementSource_; + using ElementTensor = ElementTensor_; + + static bool const kIsHeavy = true; + + static int const kCount = Count; + + using FragmentCompute = Array; + using FragmentAccumulator = Array; + using FragmentSource = Array; + using FragmentTensor = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute 
threshold; ///< minimum value that is output + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + threshold(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta, + ElementCompute threshold = ElementCompute(0) + ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr, + ElementCompute threshold = ElementCompute(0) + ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + ElementCompute threshold_; + bool participates_in_reduction_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationWithElementwise(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); + threshold_ = params.threshold; + participates_in_reduction_ = true; + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Returns true if the threadblock computes the reduction + CUTLASS_HOST_DEVICE + bool participates_in_reduction() const { + return participates_in_reduction_; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + // set to NaN to make ReLU no-op for all except last k partitions + int64_t allones = -1; + threshold_ = reinterpret_cast(allones); + // Avoid computing the reduction if this isn't the final Split-K slice + participates_in_reduction_ = false; + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentSource const &source, + FragmentTensor const &tensor) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + + return intermediate; + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentCompute operator()( + FragmentAccumulator const &accumulator, + FragmentTensor const &tensor) const { + + 
// Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_accumulator; + + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + return intermediate; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/scale_type.h b/include/cutlass/epilogue/thread/scale_type.h index 200db83a..58515fbc 100644 --- a/include/cutlass/epilogue/thread/scale_type.h +++ b/include/cutlass/epilogue/thread/scale_type.h @@ -41,9 +41,10 @@ namespace thread { /// Specifies internal data type for computation struct ScaleType { enum Kind { - Default, // alpha x C + beta x D - NoBetaScaling, // alpha x C + D - OnlyAlphaScaling // alpha x C + Default, // alpha x C + beta x D + NoBetaScaling, // alpha x C + D + OnlyAlphaScaling, // alpha x C + Nothing // C }; }; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index 3420cec7..d2e17421 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -44,7 +44,6 @@ #include "cutlass/epilogue/thread/linear_combination_gelu.h" #include "cutlass/epilogue/thread/linear_combination_sigmoid.h" #include "cutlass/epilogue/thread/linear_combination_planar_complex.h" - #include "cutlass/epilogue/thread/conversion_op.h" #include "cutlass/epilogue/thread/reduction_op.h" @@ -55,6 +54,8 @@ #include "cutlass/epilogue/threadblock/default_thread_map_simt.h" #include 
"cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" #include "cutlass/epilogue/threadblock/epilogue.h" @@ -144,6 +145,164 @@ struct DefaultEpilogueSimt { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines sensible defaults for epilogues for SimtOps. +template < + typename Shape_, + typename WarpMmaSimt_, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpilogueSimtStridedDgrad { + + using Shape = Shape_; + using WarpMmaSimt = WarpMmaSimt_; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaSimt::LayoutC; + using ElementAccumulator = typename WarpMmaSimt::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< + Shape, + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::Policy, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::ThreadMma, + layout::RowMajor, + typename WarpMmaSimt::Policy + >; + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::ThreadMma, + ElementAccumulator, + layout::RowMajor, + typename WarpMmaSimt::Policy + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< + typename 
OutputTileThreadMap::CompactedThreadMap, + ElementAccumulator + >; + + /// Hard-coded padding elements added + using Padding = typename WarpTileIterator::Padding; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaSimt, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for SimtOps. +template < + int Rank, + typename Shape_, + typename WarpMmaSimt_, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpilogueSimtAffineRankN { + + using Shape = Shape_; + using WarpMmaSimt = WarpMmaSimt_; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaSimt::LayoutC; + using ElementAccumulator = typename WarpMmaSimt::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< + Shape, + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::Policy, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< + OutputTileThreadMap, + ElementOutput, + Rank + >; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::ThreadMma, + layout::RowMajor, + typename WarpMmaSimt::Policy + >; + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::ThreadMma, + ElementAccumulator, + layout::RowMajor, + typename WarpMmaSimt::Policy + >; + + using 
SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< + typename OutputTileThreadMap::CompactedThreadMap, + ElementAccumulator + >; + + /// Hard-coded padding elements added + using Padding = typename WarpTileIterator::Padding; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaSimt, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 5538fa8a..79ecbdf2 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -56,6 +56,8 @@ #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" @@ -364,6 +366,188 @@ struct DefaultEpilogueTensorOp { //////////////////////////////////////////////////////////////////////////////// +/// Defines sensible defaults for epilogues for TensorOps. 
+template < + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpilogueTensorOpStridedDgrad { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + /// Support several implementations depending on structure of epilogue + using DefaultIterators = detail::DefaultIteratorsTensorOp< + ElementOutput, + ElementAccumulator, + kElementsPerAccess, + Shape, + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename OutputTileThreadMap::CompactedThreadMap + >; + + using WarpTileIterator = typename DefaultIterators::WarpTileIterator; + using 
SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; + + /// Hard-coded padding elements added + using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; + + static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding, + kFragmentsPerIteration + >; +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template < + int Rank, + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpilogueTensorOpAffineRankN { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< + OutputTileThreadMap, + ElementOutput, + Rank + >; + + // Map to the row major iterator since the iterator selection for affineN is the same. 
+ using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + layout::RowMajor>, + cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + layout::RowMajor> >::type; + + /// Support several implementations depending on structure of epilogue + using DefaultIterators = detail::DefaultIteratorsTensorOp< + ElementOutput, + ElementAccumulator, + kElementsPerAccess, + Shape, + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename OutputTileThreadMap::CompactedThreadMap + >; + + using WarpTileIterator = typename DefaultIterators::WarpTileIterator; + using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; + + /// Hard-coded padding elements added + using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; + + static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding, + kFragmentsPerIteration + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Defines sensible defaults for epilogues for TensorOps which uses /// intereleaved output layout. For this case, shared memory is not needed. 
template +struct DefaultEpilogueVoltaTensorOpStridedDgrad { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess, + ElementAccumulator + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< + typename WarpMmaTensorOp::Shape, + gemm::GemmShape<32, 32, 4>, + ElementAccumulator, + LayoutC + >; + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< + typename WarpMmaTensorOp::Shape, + gemm::GemmShape<32, 32, 4>, + ElementAccumulator, + LayoutC + >; + + static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; + + static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< + typename OutputTileThreadMap::CompactedThreadMap, + ElementAccumulator, + kSharedMemAlignment + >; + + /// Hard-coded padding elements added + using Padding = typename WarpTileIterator::Padding; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding + >; +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template < + int Rank, + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultEpilogueVoltaTensorOpAffineRankN { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess, + ElementAccumulator + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< + OutputTileThreadMap, + ElementOutput, + Rank + >; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< + typename WarpMmaTensorOp::Shape, + gemm::GemmShape<32, 32, 4>, + ElementAccumulator, + LayoutC + >; + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< + typename WarpMmaTensorOp::Shape, + gemm::GemmShape<32, 32, 4>, + ElementAccumulator, + LayoutC + >; + + static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; + + static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< + typename OutputTileThreadMap::CompactedThreadMap, + ElementAccumulator, + kSharedMemAlignment + >; + + /// Hard-coded padding elements added + using Padding = typename WarpTileIterator::Padding; + 
+ // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::Epilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h new file mode 100644 index 00000000..9f2b263e --- /dev/null +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -0,0 +1,171 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
+template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultEpilogueWithBroadcastTensorOp { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + ElementsPerAccess + >; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementOutput + >; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementTensor + >; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputTileIterator, + TensorTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding, + Base::kFragmentsPerIteration + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for VoltaTensorOps. 
+template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultEpilogueWithBroadcastVoltaTensorOp { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueVoltaTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + ElementsPerAccess + >; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementOutput + >; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementTensor + >; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputTileIterator, + TensorTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h new file mode 100644 index 00000000..a63486f0 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
 + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename OutputOp, + typename ReductionOp, + int ElementsPerAccess +> +struct DefaultEpilogueWithReductionTensorOp { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + ElementsPerAccess + >; + + /// Additional tensor tile iterator + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + typename OutputOp::ElementTensor + >; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementOutput + >; + + /// Define the epilogue + using Epilogue = EpilogueWithReduction< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputTileIterator, + TensorTileIterator, + typename WarpMmaTensorOp::ElementC, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + typename Base::OutputOp, + ReductionOp, + typename Base::Padding + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
+template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename OutputOp, + typename ReductionOp, + int ElementsPerAccess +> +struct DefaultEpilogueWithReductionVoltaTensorOp { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueVoltaTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + ElementsPerAccess + >; + + /// Additional tensor tile iterator + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + typename OutputOp::ElementTensor + >; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementOutput + >; + + /// Define the epilogue + using Epilogue = EpilogueWithReduction< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputTileIterator, + TensorTileIterator, + typename WarpMmaTensorOp::ElementC, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + typename Base::OutputOp, + ReductionOp, + typename Base::Padding + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 9afd3d5f..042a6d34 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -54,6 +54,7 @@ #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/util/index_sequence.h" //////////////////////////////////////////////////////////////////////////////// @@ -74,7 +75,9 @@ template < typename SharedLoadIterator_, ///< Threadblock-scoped tile 
iterator loading from SMEM typename OutputOp_, ///< Output operator typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) - int FragmentsPerPartition = 1 ///< Used to coarsten the epilogue granularity + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) > class Epilogue : public EpilogueBase< @@ -141,8 +144,8 @@ public: /// Number of warps using WarpCount = typename Base::WarpCount; - int const kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; - int const kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; public: @@ -194,8 +197,52 @@ public: private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator 
const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); - + /// Streams the result to global memory CUTLASS_DEVICE void compute_source_not_needed_( @@ -214,7 +261,7 @@ private: // Iterate over accumulator tile // - CUTLASS_PRAGMA_UNROLL + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { // @@ -224,23 +271,11 @@ private: __syncthreads(); - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - typename AccumulatorFragmentIterator::Fragment accum_fragment; - - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - this->warp_tile_iterator_.store(accum_fragment); - - if (p < Base::kFragmentsPerIteration - 1) { - this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset); - } - } - - if (Base::kFragmentsPerIteration > 1) { - this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); - } + acc2smem_source_not_needed< + cutlass::make_index_sequence>::push(iter, + accum_fragment_iterator, + this->warp_tile_iterator_); __syncthreads(); @@ -295,7 +330,34 @@ private: } } } - + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + 
warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + /// Streams the result to global memory CUTLASS_DEVICE void compute_source_needed_( @@ -319,7 +381,7 @@ private: // Iterate over accumulator tile // - CUTLASS_PRAGMA_UNROLL + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { // @@ -335,12 +397,8 @@ private: __syncthreads(); - typename AccumulatorFragmentIterator::Fragment accum_fragment; - - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - this->warp_tile_iterator_.store(accum_fragment); + acc2smem_source_needed>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); __syncthreads(); diff --git a/include/cutlass/epilogue/threadblock/epilogue_base.h b/include/cutlass/epilogue/threadblock/epilogue_base.h index 76692d43..d0d85adc 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -32,6 +32,9 @@ #pragma once +#include +#include + #if defined(__CUDACC_RTC__) #include #else @@ -59,6 +62,32 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// +// +// This is used for metaprogramming epilogue functors. If they define +// `static bool const kIsHeavy = true;`, then the epilogue functor itself is +// not inlined. This results in smaller code and is advantageous if the epilogue +// functor consists of many instructions. +// +// If the epilogue functor does not define `kIsHeavy` or if it is `false`, then +// the behavior from CUTLASS 2.5 and before is retained. The epilogue is fully +// unrolled and inlined. 
+// + +template +struct TypeSink { typedef void type; }; + +template using TypeSinkT = typename TypeSink::type; + +template struct IsEpilogueFunctorHeavy { + static bool const value = false; +}; + +template struct IsEpilogueFunctorHeavy > { + static bool const value = T::kIsHeavy; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Base class for epilogues defining warp-level template < typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) diff --git a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h new file mode 100644 index 00000000..2fef4eb7 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h @@ -0,0 +1,207 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/util/index_sequence.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename ElementAccumulator_, + typename ElementOutput_, + typename ThreadBlockShape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: 
gemm::warp::MmaTensorOp) + bool ReduceKForA_ +> +class EpilogueGemmKReduction { + +public: + + using ThreadBlockShape = ThreadBlockShape_; + using WarpMmaOperator = WarpMmaOperator_; + using WarpShape = typename WarpMmaOperator::Shape; + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// Accumulator element + using ElementAccumulator = ElementAccumulator_; + + /// Output element + using ElementOutput = ElementOutput_; + + /// Output access size + static int const kElementsPerAccess = 1; + + static bool const kReduceKForA = ReduceKForA_; + + static int const kThreadBlockSize = kReduceKForA ? ThreadBlockShape::kM : ThreadBlockShape::kN; + + static int const kWarpSize = kReduceKForA ? WarpShape::kM : WarpShape::kN; + + static int const kIterations = kWarpSize / 8; + + using FragmentAccumulator = Array; + +private: + + int thread_offset_; + ElementOutput* pointer_; + int col_; +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueGemmKReduction( + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx, ///< Id of thread within warp + int threadblock_offset, + ElementOutput* pointer + ) + { + col_ = lane_idx % 4; + thread_offset_ = threadblock_offset * kThreadBlockSize + + warp_idx * kWarpSize + + lane_idx / 4 + col_ * 8; + + pointer_ = pointer + LongIndex(thread_offset_); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + int size, + FragmentAccumulator &gemm_k_with_reduction_accumulation, + bool LoadForSerialSplitK + ) { + bool guard[kIterations / 4]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations / 4; ++i) { + guard[i] = ((thread_offset_ + i * 32) < size); + } + + Array source; + source.clear(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations / 4; ++i) { + ElementOutput tmp; + cutlass::arch::global_load( + tmp, + (void *)(pointer_ + i * 32), + guard[i] && LoadForSerialSplitK); + + source[i] = tmp; + 
} + + FragmentAccumulator sum = gemm_k_with_reduction_accumulation; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations; ++i) { + sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 1); + sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 2); + } + + Array intermediate; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations / 4; ++i) { + if (col_ == 0) { + intermediate[i] = sum[0 + i * 4]; + } + + if (col_ == 1) { + intermediate[i] = sum[1 + i * 4]; + } + + if (col_ == 2) { + intermediate[i] = sum[2 + i * 4]; + } + + if (col_ == 3) { + intermediate[i] = sum[3 + i * 4]; + } + } + + NumericArrayConverter source_converter; + Array converted_source = source_converter(source); + + plus> plus_source; + intermediate = plus_source(intermediate, converted_source); + + NumericArrayConverter converter; + Array result = converter(intermediate); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations / 4; ++i) { + cutlass::arch::global_store(result[i], + (void *)(pointer_ + i * 32), guard[i]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h new file mode 100644 index 00000000..6c6dc530 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -0,0 +1,817 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +#include "cutlass/util/index_sequence.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template < + typename ElementC_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementZ_, + typename ElementT_, + int ElementsPerAccess, + bool StoreZ = true, + bool StoreT = true +> +struct EpilogueWithBroadcastOpBase { + + using ElementOutput = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = StoreZ; + + /// If true, the 'T' tensor is stored + static bool const kStoreT = StoreT; + + /// Parameters structure - required 
+ struct Params { }; + + // + // Methods + // + + /// Constructor from Params + EpilogueWithBroadcastOpBase(Params const ¶ms_) { } + + /// Determine if the source is needed. May return false if + bool is_source_needed() const { + return true; + } + + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentC const &frag_C, + FragmentCompute const &V) const { + + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentCompute const &V) const { + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator with bias vector broadcast over columns. +/// +/// Computes the following: +/// +/// +/// Z, T = OutputOp(AB, C, Broadcast) +/// +/// if (ElementwiseOp::kStoreZ) { +/// store(converted_u); +/// } +/// +/// if (ElementwiseOp::kStoreT) { +/// store(v); +/// } +/// +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) + typename ElementVector_, ///< Pointer to broadcast vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator - concept is 
EpilogueWithBroadcastOp + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class EpilogueWithBroadcast : + public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Compute data type produced by the output op + using ElementCompute = typename OutputOp::ElementCompute; + + /// Compute fragment + using FragmentCompute = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = Array< + ElementCompute, + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; + + /// Output element + using ElementOutput 
= typename OutputTileIterator::Element; + + /// Data type of additional tensor + using ElementTensor = typename TensorTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Array type used by output functor + using ComputeAccessType = Array; + + /// Tensor access type + using TensorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + /// Shared memory allocation from epilogue base class + using BaseSharedStorage = typename Base::SharedStorage; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + /// Used for the broadcast + struct BroadcastDetail { + + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = kWarpSize * WarpCount::kCount; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// I'm not sure what I meant here. 
+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape< + kThreadRows, + Shape::kN + >; + + /// Debug printing + CUTLASS_DEVICE + static void print() { + printf("BroadcastDetail {\n"); + printf( + " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" + "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", + kColumnsPerThread, + kRowsPerThread, + kThreadCount, + kThreadsPerRow, + kThreadRows, + kThreadAccessesPerRow, + StorageShape::kRow, + StorageShape::kColumn, + StorageShape::kCount + ); + printf("};\n"); + } + }; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + struct SharedStorage { + union { + BaseSharedStorage base; + }; + + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + +public: + + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Thread index within the threadblock + int thread_idx_; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueWithBroadcast( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + 
thread_idx_(thread_idx) + { + + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) { + + BroadcastFragment broadcast_fragment; + + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); + + if (!output_op.is_source_needed()) { + compute_source_not_needed_( + output_op, + broadcast_fragment, + destination_iterator, + accumulators, + tensor_iterator); + } + else { + compute_source_needed_( + output_op, + broadcast_fragment, + destination_iterator, + accumulators, + source_iterator, + tensor_iterator); + } + } + +private: + + CUTLASS_DEVICE + void load_broadcast_fragment_( + BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + ElementVector const * broadcast_ptr, ///< Broadcast vector + MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space + ) { + + broadcast_fragment.clear(); + + // If no pointer is supplied, set with all zeros and avoid memory accesses + if (!broadcast_ptr) { + return; + } + + int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); + + int 
thread_column_idx = threadblock_offset.column() + thread_initial_column; + broadcast_ptr += thread_initial_column; + + NumericArrayConverter converter; + using AccessType = AlignedArray; + using ComputeFragmentType = Array; + + ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { + + AccessType loaded; + + loaded.clear(); + + if (thread_column_idx < problem_size.column()) { + loaded = *reinterpret_cast(broadcast_ptr); + } + + ComputeFragmentType cvt = converter(loaded); + frag_ptr[j] = cvt; + + thread_column_idx += ThreadMap::Delta::kColumn; + broadcast_ptr += ThreadMap::Delta::kColumn; + } + } + + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + 
OutputOp const &output_op, ///< Output operator + BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + // CUTLASS_PRAGMA_UNROLL + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { + + // + // Convert and store fragment + // + + + __syncthreads(); + + acc2smem_source_not_needed< + cutlass::make_index_sequence>::push(iter, + accum_fragment_iterator, + this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } + else if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename 
TensorTileIterator::Fragment frag_T; + + apply_output_operator_source_not_needed_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + broadcast_fragment); + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { + + typename OutputTileIterator::Fragment source_fragment; + 
source_fragment.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + source_iterator.load(source_fragment); + ++source_iterator; + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) + { + plus add_fragments; + const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + apply_output_operator_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + source_fragment, + broadcast_fragment); + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + } + + /// Helper to invoke the output functor over each vector of 
output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment &frag_Z, + typename TensorTileIterator::Fragment &frag_T, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &frag_AB, + typename OutputTileIterator::Fragment const &frag_C, + BroadcastFragment const &frag_Broadcast) { + + using AccessTypeZ = Array; + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); + AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const *frag_AB_ptr = + reinterpret_cast(&frag_AB); + + OutputAccessType const *frag_C_ptr = + reinterpret_cast(&frag_C); + + AccessTypeBroadcast const *frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_C_ptr[i], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment &frag_Z, + typename TensorTileIterator::Fragment &frag_T, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &frag_AB, + BroadcastFragment const &frag_Broadcast) { + + using AccessTypeZ = Array; + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); + AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const *frag_AB_ptr = + reinterpret_cast(&frag_AB); + + AccessTypeBroadcast const *frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h new file mode 100644 index 00000000..ae242629 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -0,0 +1,728 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator with reduction over each column +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename 
WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands + typename ElementVector_, ///< Pointer to reduction vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator + typename ReductionOp_, ///< Reduction operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class EpilogueWithReduction : + public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_> { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using ReductionOp = ReductionOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename 
Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Compute data type produced by the output op + using ElementCompute = typename OutputOp::ElementCompute; + + /// Compute fragment + using FragmentCompute = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + + /// Fragment object used in reduction + using ReductionFragment = Array< + ElementAccumulator, + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Data type of additional tensor + using ElementTensor = typename TensorTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Array type used by output functor + using ComputeAccessType = Array; + + /// Tensor access type + using TensorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + /// Shared memory allocation from epilogue base class + using BaseSharedStorage = typename Base::SharedStorage; + + /// Used for the reduction + struct ReductionDetail { + + /// Number of threads per warp + static int const kWarpSize = 32; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * 
ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = kWarpSize * WarpCount::kCount; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// I'm not sure what I meant here. + static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape< + kThreadRows, + Shape::kN + >; + + /// Debug printing + CUTLASS_DEVICE + static void print() { + printf("ReductionDetail {\n"); + printf( + " kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" + "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", + kElementsPerAccess, + kColumnsPerThread, + kRowsPerThread, + kThreadCount, + kThreadsPerRow, + kThreadRows, + kThreadAccessesPerRow, + StorageShape::kRow, + StorageShape::kColumn, + StorageShape::kCount + ); + printf("};\n"); + } + }; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + struct SharedStorage { + union { + BaseSharedStorage base; + AlignedArray reduction; ///< Shared storage for reduction + }; + + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + +public: + + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + 
+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Shared memory pointer fo rreduction + ElementAccumulator *reduction_ptr_; + + /// Thread index within the threadblock + int thread_idx_; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueWithReduction( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + reduction_ptr_(shared_storage.reduction.data()), + thread_idx_(thread_idx) + { + + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + ElementVector * reduction_output_ptr, ///< Reduction output vector + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) { + + ReductionFragment reduction_fragment; + reduction_fragment.clear(); + + if (!output_op.is_source_needed()) { + compute_source_not_needed_( + output_op, + reduction_fragment, + destination_iterator, + accumulators, + tensor_iterator); + } + else { + compute_source_needed_( + output_op, 
+ reduction_fragment, + destination_iterator, + accumulators, + source_iterator, + tensor_iterator); + } + + if (output_op.participates_in_reduction()) { + reduction_(problem_size, threadblock_offset, reduction_output_ptr, reduction_fragment); + } + } + +private: + + /// Perform the reduction + CUTLASS_DEVICE + void reduction_( + MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const &threadblock_offset, ///< Problem size needed to guard against out-of-bounds accesses + ElementVector * reduction_output_ptr, ///< Reduction output vector + ReductionFragment const & reduction_fragment) { + + // + // Store the partially reduced value to SMEM + // + + // Guard against uses of the existing SMEM tile + __syncthreads(); + + using AccessType = AlignedArray; + + // + // Determine a compacted thread arrangement to store to SMEM. + // + int const kThreadsPerRow = Shape::kN / (ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess); + + MatrixCoord thread_offset( + thread_idx_ / kThreadsPerRow, + (thread_idx_ % kThreadsPerRow) * ThreadMap::kElementsPerAccess); + + // + // Each thread store its fragment to a SMEM + // + + AccessType *aligned_reduction_ptr = reinterpret_cast( + &reduction_ptr_[thread_offset.row() * Shape::kN + thread_offset.column()]); + + AccessType const *frag_ptr = reinterpret_cast(&reduction_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess; + + aligned_reduction_ptr[col_idx] = frag_ptr[column]; + } + + __syncthreads(); + + // + // Now, threads are assigned several columns of the output. They fetch over all rows from + // the compacted SMEM tile and perform a reduction. 
+ // + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) { + int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount; + + ReductionOp reduction_op; + ElementAccumulator reduction_element = ElementAccumulator(); + + int output_column_idx = threadblock_offset.column() + column_idx; + + if (column_idx < Shape::kN && output_column_idx < problem_size.column()) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ReductionDetail::kThreadRows; ++row) { + if (row) { + auto frag = reduction_ptr_[row * Shape::kN + column_idx]; + + reduction_element = reduction_op(reduction_element, frag); + } + else { + + reduction_element = reduction_ptr_[column_idx]; + } + } + + // Store + reduction_output_ptr[column_idx] = ElementVector(reduction_element); + } + } + } + + template + struct acc2smem; + + template + struct acc2smem> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const &output_op, ///< Output operator + ReductionFragment &reduction_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator 
for additioanl tensor operand + ) { + + // + // Iterator over warp-level accumulator fragment + // + + typename TensorTileIterator::Fragment tensor_fragment; + tensor_fragment.clear(); + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Convert and store fragment + // + + tensor_iterator.load(tensor_fragment); + ++tensor_iterator; + + __syncthreads(); + + acc2smem>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + // + if (kPartitionsK > 1) + { + plus add_fragments; + const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); + } + + // + // Compute the output result + // + + FragmentCompute compute_fragment; + + apply_output_operator_source_not_needed_( + reduction_fragment, + compute_fragment, + output_op, + aligned_accum_fragment[0], + tensor_fragment); + + // + // Store the final result + // + + NumericArrayConverter converter; + + typename OutputTileIterator::Fragment output_fragment = converter(compute_fragment); + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + + /// Streams the result to 
global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + ReductionFragment &reduction_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { + + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + typename TensorTileIterator::Fragment tensor_fragment; + tensor_fragment.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + source_fragment.clear(); + source_iterator.load(source_fragment); + ++source_iterator; + + tensor_iterator.load(tensor_fragment); + ++tensor_iterator; + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) + { + plus add_fragments; + const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); + 
shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); + } + + // + // Compute the output result + // + + FragmentCompute compute_fragment; + + apply_output_operator_( + reduction_fragment, + compute_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment, + tensor_fragment); + + // + // Convert and store the final result + // + + NumericArrayConverter converter; + + typename OutputTileIterator::Fragment output_fragment = converter(compute_fragment); + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + ReductionFragment &reduction_fragment, + FragmentCompute &compute_fragment, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment, + typename OutputTileIterator::Fragment const &source_fragment, + typename TensorTileIterator::Fragment const &tensor_fragment) { + + ComputeAccessType *compute_frag_ptr = + reinterpret_cast(&compute_fragment); + + AccumulatorAccessType const *accum_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + OutputAccessType const *source_frag_ptr = + reinterpret_cast(&source_fragment); + + TensorAccessType const *tensor_frag_ptr = + reinterpret_cast(&tensor_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operator + compute_frag_ptr[i] = output_op(accum_frag_ptr[i], source_frag_ptr[i], tensor_frag_ptr[i]); + } + + // + // Partial reduction over each column + // + + ReductionOp reduction_op; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column 
< ReductionDetail::kColumnsPerThread; ++column) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ReductionDetail::kRowsPerThread; ++row) { + reduction_fragment[column] = reduction_op( + reduction_fragment[column], + compute_fragment[row * ReductionDetail::kColumnsPerThread + column]); + } + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + ReductionFragment &reduction_fragment, + FragmentCompute &compute_fragment, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment, + typename TensorTileIterator::Fragment const &tensor_fragment) { + + ComputeAccessType *compute_frag_ptr = + reinterpret_cast(&compute_fragment); + + AccumulatorAccessType const *accum_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + TensorAccessType const *tensor_frag_ptr = + reinterpret_cast(&tensor_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operator + compute_frag_ptr[i] = output_op(accum_frag_ptr[i], tensor_frag_ptr[i]); + } + + // + // Partial reduction over each column + // + + ReductionOp reduction_op; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ReductionDetail::kColumnsPerThread; ++column) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ReductionDetail::kRowsPerThread; ++row) { + reduction_fragment[column] = reduction_op( + reduction_fragment[column], + compute_fragment[row * ReductionDetail::kColumnsPerThread + column]); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h index 7bf7b4de..66ca4f30 100644 --- a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h +++ b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -253,7 +253,7 @@ class InterleavedEpilogue { // typename OutputTileIterator::Fragment output_fragment; - apply_output_operator_(output_op, output_fragment, accum_fragment, source_fragment); + apply_output_operator_source_needed_(output_op, output_fragment, accum_fragment, source_fragment); // // Store the final result @@ -268,7 +268,7 @@ class InterleavedEpilogue { private: /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE - void apply_output_operator_( + void apply_output_operator_source_needed_( OutputOp const &output_op, ///< Output operator typename OutputTileIterator::Fragment &output_fragment, typename AccumulatorFragmentIterator::Fragment const diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index 377f33bd..b234520f 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -164,7 +164,7 @@ template < > struct RowArrangement { - static int const kMemoryAccessSize = 128; + static int const kMemoryAccessSize = 256; // Preferred access size static int const kWarpSize = 32; static int const kElementsPerAccess = ElementsPerAccess; diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index a4a5d15a..4ac307ae 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -56,7 +56,7 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -/// Tile 
iterator used to load and store output tile from shared memory in epilogue. +/// Tile iterator used to load and store output tile from global memory in epilogue. /// /// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator /// @@ -105,6 +105,7 @@ public: /// Uses a non-template class struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; CUTLASS_HOST_DEVICE Params() { } @@ -115,9 +116,11 @@ public: layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, make_OutputTileThreadMapDesc() ) - { - - } + { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } }; /// Mask object @@ -177,6 +180,14 @@ private: /// Internal state counter int state_[3]; + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + private: // @@ -236,7 +247,7 @@ public: /// Loads a fragment from memory CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { uint8_t *byte_pointer = byte_pointer_; AccessType *frag_ptr = reinterpret_cast(&frag); @@ -292,18 +303,16 @@ public: } } } - - /// Loads a fragment from memory CUTLASS_DEVICE - void load(Fragment &frag) { + void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } /// Stores a fragment to memory CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { uint8_t *byte_pointer = byte_pointer_; AccessType const *frag_ptr = reinterpret_cast(&frag); @@ -357,7 +366,7 @@ public: /// Stores a fragment to memory CUTLASS_DEVICE - void store(Fragment const &frag) { + void store(Fragment const &frag) const { 
store_with_byte_offset(frag, 0); } @@ -421,7 +430,7 @@ public: //////////////////////////////////////////////////////////////////////////////// -/// Tile iterator used to load output tile from shared memory in epilogue. +/// Tile iterator used to load output tile from global memory in epilogue. /// /// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator /// @@ -454,51 +463,23 @@ public: /// Memory access size using AccessType = AlignedArray; - // - // Parameters struct - // - - struct Params { - - // - // Data members - // - - LongIndex stride; ///< stride in bytes between columns - - LongIndex advance_row; ///< amount to add to move to the next 'row' position - LongIndex advance_column; ///< amount to add to move to the next 'column' position - - // - // Methods - // + /// Uses a non-template class + struct Params : InterleavedPredicatedTileIteratorParams { + using Base = InterleavedPredicatedTileIteratorParams; CUTLASS_HOST_DEVICE - Status initialize(Index stride_) { - - stride = LongIndex(stride_); - - advance_row = - ThreadMap::Delta::kContiguous * sizeof_bits::value / 8; - - advance_column = LongIndex(stride_) - ThreadMap::Iterations::kContiguous * - kElementsPerAccess * - sizeof_bits::value * - ThreadMap::kWarpSize / 8; - - return Status::kSuccess; - } + Params() { } CUTLASS_HOST_DEVICE - Params() { - initialize(0); - } + Params(Layout const &layout): + Base( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_InterleavedPredicatedTileIteratorDesc() + ) { } CUTLASS_HOST_DEVICE - Params(Layout const &layout) { - - initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess); - } + Params(Base const &base) : + Base(base) { } }; /// Mask object @@ -705,7 +686,7 @@ public: /////////////////////////////////////////////////////////////////////////////// -/// Tile iterator used to load output tile from shared memory in epilogue. 
+/// Tile iterator used to load output tile from global memory in epilogue. /// /// Satisfies: ReadableTileIterator | InterleavedMaskedTileIterator | ForwardTileIterator /// diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h new file mode 100644 index 00000000..e93d736e --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h @@ -0,0 +1,602 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +/// It provides a fast path for the case Rank = 2 which does not need div/rem to +/// calculate modes. + +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + int Rank +> +class PredicatedTileIteratorAffineRankN { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::AffineRankN; + using TensorRef = TensorRef; + using TensorView = TensorView; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = typename Layout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + static_assert( !(Layout::kRank % 2), + "Layout rank must be even. 
This assumes the first half of the modes correspond to the 'row' " + "and the second half of the modes correspond to the 'column'"); + + static bool const kBigEndian = false; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Parameters structure + struct Params { + + // + // Data members + // + + Layout layout; + + /// Stride in units of bytes along M modes + Coord stride_m; + + /// Stride in units of bytes along N modes + Coord stride_n; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + int64_t rank2_inc_col; + int64_t rank2_inc_row; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(TensorCoord const &extent, Layout const &layout_): layout(layout_) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + if (kBigEndian) { + // "Big Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i + 1]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); + } + } + else { + // "Little Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); + } + } + + #if 0 + // + // Debug print statements to verify extents and strides are passed correctly. 
+ // + printf("PredicatedTileIteratorAffine::Params() entered\n"); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + printf(" extent[%d]: %d\n", i, extent[i]); + } + for (int i = 0; i < Layout::kRank; ++i) { + printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); + } + printf("PredicatedTileIteratorAffine::Params() returning\n"); + #endif + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout_): layout(layout_) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0]; + rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0]; + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ Params params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_col_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have been computed) + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorAffineRankN( + Params const & params, + Element *pointer, + MatrixCoord extent, + int thread_idx, + MatrixCoord threadblock_offset = MatrixCoord() + ): + params_(params) + { + + MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_col_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + if (Layout::kRank > 2) { + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + if (!pointer) { + mask_.clear(); + } + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment 
from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { + uint8_t const *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + int64_t offset_modes_m = row_begin * params_.stride_m[0]; + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + // + // Compute coordinate and decompose into M modes + // + + int coord_m = row * ThreadMap::Delta::kRow + row_begin; + + Coord modes_m; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); + } else { + modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); + } + + offset_modes_m = dot(modes_m, params_.stride_m); + } + + // + // Compute the offset due to modes M + // + + bool row_guard = (coord_m < extent_row_); + int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + // + // Compute coordinate and decompose into N modes + // + + int coord_n = thread_start_column_ + column * ThreadMap::Delta::kColumn; + + Coord modes_n; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_n = CoordinateDecomposition(coord_n, params_.divmod_n); + } else { + modes_n = CoordinateDecompositionLittleEndian(coord_n, params_.divmod_n); + } + + offset_modes_n = dot(modes_n, params_.stride_n); + } + + // + // Compute the pointer and access + // + bool guard; + + if (Layout::kRank > 2) { + guard = row_guard && 
mask_.predicates[column]; + } else { + guard = (coord_m < extent_row_) && + ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); + } + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), + guard + ); + + if (Layout::kRank == 2) { + offset_modes_n += params_.rank2_inc_col; + } + } + + if (Layout::kRank == 2) { + offset_modes_m += params_.rank2_inc_row; + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + int64_t offset_modes_m = row_begin * params_.stride_m[0]; + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + // + // Compute coordinate and decompose into M modes + // + + int coord_m = row * ThreadMap::Delta::kRow + row_begin; + + Coord modes_m; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); + } else { + modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); + } + + offset_modes_m = dot(modes_m, params_.stride_m); + } + + // + // Compute the offset due to modes M + // + + bool row_guard = (coord_m < extent_row_); + int64_t 
offset_modes_n = thread_start_column_ * params_.stride_n[0]; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + // + // Compute coordinate and decompose into N modes + // + + int coord_n = thread_start_column_ + column * ThreadMap::Delta::kColumn; + + Coord modes_n; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_n = CoordinateDecomposition(coord_n, params_.divmod_n); + } + else { + modes_n = CoordinateDecompositionLittleEndian(coord_n, params_.divmod_n); + } + + offset_modes_n = dot(modes_n, params_.stride_n); + } + + // + // Compute the pointer and access + // + bool guard; + if (Layout::kRank > 2) { + guard = row_guard && mask_.predicates[column]; + } else { + guard = (coord_m < extent_row_) && ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); + } + + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), + guard); + + if (Layout::kRank == 2) { + offset_modes_n += params_.rank2_inc_col; + } + } + + if (Layout::kRank == 2) { + offset_modes_m += params_.rank2_inc_row; + } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + + store_with_byte_offset(frag, 0); + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineRankN &operator++() { + + ++state_[0]; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + 
state_[2] = 0; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index d73ce1bd..300379cd 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -138,11 +138,10 @@ OutputTileThreadMapDesc make_OutputTileThreadMapDesc() { make_OutputTileShapeDesc() ); } - -///////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// // -// Parameters struct +// Parameters struct for PredicatedTileIterator // struct PredicatedTileIteratorParams { @@ -170,9 +169,9 @@ struct PredicatedTileIteratorParams { // CUTLASS_HOST_DEVICE - Status initialize(Index stride_, OutputTileThreadMapDesc thread_map) { + Status initialize(LongIndex stride_, OutputTileThreadMapDesc thread_map) { - stride = LongIndex(stride_); + stride = stride_; increment_row = stride * thread_map.delta.row; @@ -206,19 +205,166 @@ struct PredicatedTileIteratorParams { return Status::kSuccess; } + CUTLASS_HOST_DEVICE + Status initialize(Index stride_, OutputTileThreadMapDesc thread_map) { + return 
initialize(LongIndex(stride_), thread_map); + } + CUTLASS_HOST_DEVICE PredicatedTileIteratorParams() { - initialize(0, OutputTileThreadMapDesc()); + initialize(LongIndex(0), OutputTileThreadMapDesc()); } CUTLASS_HOST_DEVICE PredicatedTileIteratorParams(Index stride, OutputTileThreadMapDesc thread_map) { + initialize(stride, thread_map); + } + CUTLASS_HOST_DEVICE + PredicatedTileIteratorParams(LongIndex stride, OutputTileThreadMapDesc thread_map) { initialize(stride, thread_map); } }; + /////////////////////////////////////////////////////////////////////////////// +// InterleavedPredicatedTileIterator +/////////////////////////////////////////////////////////////////////////////// + + +/// Predicated tile access iterator descriptor object containing template dependent state +struct InterleavedPredicatedTileIteratorDesc { + + int element_size_bits; + int elements_per_access; + int threadmap_warp_size; + layout::PitchLinearCoord threadmap_iterations; + layout::PitchLinearCoord threadmap_delta; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorDesc() { } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorDesc( + int element_size_bits_, + int elements_per_access_, + int threadmap_warp_size_, + layout::PitchLinearCoord threadmap_iterations_, + layout::PitchLinearCoord threadmap_delta_ + ): + element_size_bits(element_size_bits_), + elements_per_access(elements_per_access_), + threadmap_warp_size(threadmap_warp_size_), + threadmap_iterations(threadmap_iterations_), + threadmap_delta(threadmap_delta_) { } +}; + +// +// Parameters struct InterleavedPredicatedTileIterator +// + +struct InterleavedPredicatedTileIteratorParams { + + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + + LongIndex stride; ///< stride in bytes between rows + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_column; ///< amount to add to move to the next 'column' position + + 
// + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride_, InterleavedPredicatedTileIteratorDesc desc) { + + stride = stride_; + + advance_row = desc.threadmap_delta.contiguous() * desc.element_size_bits / 8; + + advance_column = stride_ - desc.threadmap_iterations.contiguous() * + desc.elements_per_access * + desc.element_size_bits * + desc.threadmap_warp_size / 8; + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorParams() { + initialize(LongIndex(0), InterleavedPredicatedTileIteratorDesc()); + } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorParams(Index stride, InterleavedPredicatedTileIteratorDesc desc) { + initialize(stride, desc); + } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorParams(LongIndex stride, InterleavedPredicatedTileIteratorDesc desc) { + initialize(stride, desc); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. +template +CUTLASS_HOST_DEVICE +InterleavedPredicatedTileIteratorDesc make_InterleavedPredicatedTileIteratorDesc() { + return InterleavedPredicatedTileIteratorDesc( + sizeof_bits::value, + ThreadMap::kElementsPerAccess, + ThreadMap::kWarpSize, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper template to construct an MakePredicatedTileIteratorDesc from a template +// dependent state +template + struct MakePredicatedTileIteratorDesc; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for layout::RowMajor output data. 
+template +struct MakePredicatedTileIteratorDesc < + Element, layout::RowMajor, ThreadMap> { + + CUTLASS_HOST_DEVICE + OutputTileThreadMapDesc operator()() { + + return make_OutputTileThreadMapDesc(); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for layout::ColumnMajorInterleaved output data. +template +struct MakePredicatedTileIteratorDesc < + Element, layout::ColumnMajorInterleaved, ThreadMap> { + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorDesc operator()() { + + return make_InterleavedPredicatedTileIteratorDesc(); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock } // namespace epilogue diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h new file mode 100644 index 00000000..35242ab4 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief PredicatedTileIteratorPredicates. + + PredicatedTileIteratorPredicates enables both upper and lower bounds for predicates. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator predicates used to bound computations in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_ ///< Element data type +> +class PredicatedTileIteratorPredicates { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 
0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { + + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index lower_extent_row_; + Index upper_extent_row_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(lower_extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(upper_extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPredicates( + PredicatedTileIteratorParams const & params, + TensorCoord lower_extent, + TensorCoord upper_extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + lower_extent_row_ = lower_extent.row(); + upper_extent_row_ = upper_extent.row(); + thread_start_row_ = thread_offset.row(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < upper_extent.column()) && + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) >= lower_extent.column()); + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPredicates &operator++() { + + ++state_[0]; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + 
+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Gets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } + + ///< Gets lower_extent_row_ + CUTLASS_DEVICE Index get_lower_extent_row() { + return lower_extent_row_; + } + + ///< Gets upper_extent_row_ + CUTLASS_DEVICE Index get_upper_extent_row() { + return upper_extent_row_; + } + + ///< Gets thread_start_row_ + CUTLASS_DEVICE Index get_thread_start_row() { + return thread_start_row_; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h new file mode 100644 index 00000000..56b1b55c --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h @@ -0,0 +1,469 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_ ///< Element data type +> +class PredicatedTileIteratorStridedDgrad { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + 
static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + + /// Convolution problem size + cutlass::conv::Conv2dProblemSize problem_size; + int tiled_rows_per_filter; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, cutlass::conv::Conv2dProblemSize problem_size_, int threadblock_row): + problem_size(problem_size_), + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_row); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_row; + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ Params params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Starting Dx h and w dimenstion for strided dgrad mapping + int start_h_, start_w_; + + /// Effective Dy P and Q dimenstions for strided dgrad mapping + int p_, q_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have been computed) + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorStridedDgrad( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + int start_r, int start_s, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + int r = start_r; + int s = start_s; + + if (params_.problem_size.mode == cutlass::conv::Mode::kConvolution) { + r = (params_.problem_size.R - 1 - r); + s = (params_.problem_size.S - 1 - s); + } + + // check if start_h_ and start_w_ are always positive + start_h_ = std::abs((params_.problem_size.pad_h - r) % params_.problem_size.stride_h); + start_w_ = std::abs((params_.problem_size.pad_w - s) % params_.problem_size.stride_w); + + p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h; + q_ = (params_.problem_size.W - start_w_ + 
params_.problem_size.stride_w - 1) / params_.problem_size.stride_w; + + extent_row_ = extent.row(); + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + // remapping rows to find the mapped_row_offset + int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; + + // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] + int n = npq_offset / (p_ * q_); + int residual = npq_offset % (p_ * q_); + int p = residual / q_; + int q = residual % q_; + + int mapped_row_offset = n * (params_.problem_size.H * 
params_.problem_size.W) + + (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + + (start_w_ + q * params_.problem_size.stride_w); + bool row_guard = mapped_row_offset < extent_row_; + + int64_t row_byte_offset = mapped_row_offset * params_.stride; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), + guard); + } + } + } + } + } + + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + // remapping rows to find the mapped_row_offset + int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; + + // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] + int n = npq_offset / (p_ * q_); + int residual = npq_offset % (p_ * q_); + int p = 
residual / q_; + int q = residual % q_; + + int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + + (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + + (start_w_ + q * params_.problem_size.stride_w); + bool row_guard = mapped_row_offset < extent_row_; + + int64_t row_byte_offset = mapped_row_offset * params_.stride; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), + guard); + } + } + } + } + } + + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + + store_with_byte_offset(frag, 0); + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorStridedDgrad &operator++() { + + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + 
mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/include/cutlass/epilogue/threadblock/shared_load_iterator.h index b5fefa26..bb43fa81 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator.h @@ -158,7 +158,7 @@ public: /// Loads a fragment from memory CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { CUTLASS_PRAGMA_UNROLL @@ -200,7 +200,7 @@ public: /// Loads a fragment CUTLASS_DEVICE - void load(Fragment &frag) { + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h index 5b31e337..decafd8d 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -158,7 +158,7 @@ public: pointers_[i] = reinterpret_cast(ref.data()); int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * sizeof(LoadType) / 128) % kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; col_idx += (bank_offset + i) % kLoadsPerAccess; @@ -187,7 +187,7 @@ public: /// Loads a fragment from memory CUTLASS_DEVICE - void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { CUTLASS_PRAGMA_UNROLL 
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { @@ -230,7 +230,7 @@ public: /// Loads a fragment CUTLASS_DEVICE - void load(Fragment &frag) { + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } diff --git a/include/cutlass/epilogue/warp/simt_policy.h b/include/cutlass/epilogue/warp/simt_policy.h index 058a6c44..cecac9d1 100644 --- a/include/cutlass/epilogue/warp/simt_policy.h +++ b/include/cutlass/epilogue/warp/simt_policy.h @@ -84,6 +84,12 @@ struct SimtPolicy { /// Number of accesses made in one iteration static int const kAccessesPerIteration = kElementsPerIteration / kElementsPerAccess; + + /// Number of elements in between accumulator chunks of (LaneMmaShape::kM x LaneMmaShape::kN) + using Delta = MatrixShape< + MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM, + MmaSimtPolicy::WarpShape::kColumn * MmaSimtPolicy::LaneMmaShape::kN + >; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_simt.h b/include/cutlass/epilogue/warp/tile_iterator_simt.h index 552f15b3..96511886 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_simt.h +++ b/include/cutlass/epilogue/warp/tile_iterator_simt.h @@ -238,6 +238,247 @@ public: ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename Operator_, ///< matrix multiply operation (concept: arch::Mma) + typename Element_, ///< data type of element to be written + typename Layout_, ///< target shared memory layout + typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +class TileIteratorSimtCanonical { +public: 
+ + using WarpShape = WarpShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = Layout_; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = SimtPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + typename Operator::ElementC, + Policy::kElementsPerIteration>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + typename Operator::ElementC, + Policy::kAccumulatorElementCount>; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Padding quantity + using Padding = MatrixShape< + 0, + 4 * Policy::kElementsPerAccess + 1 + >; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + 1 + >; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Guard to indicate whether the shape is divisible + bool divisible_; + + /// Extent of the output tensor + MatrixCoord extent_; + + /// Thread offset + MatrixCoord thread_offset_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements), + divisible_(true), + extent_(WarpShape::kM, WarpShape::kN) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); 
+ + thread_offset_ = { + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess + }; + + pointer_ += layout_({ + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical( + TensorRef const &ref, + TensorCoord const &extent, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements), + divisible_(false), + extent_(extent) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + thread_offset_ = { + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess + }; + + pointer_ += layout_({ + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / AccessType::kElements; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & add_tile_offset(TensorCoord const &tile_offset) { + + MatrixCoord coord_offset( + tile_offset.row(), + tile_offset.column() * Shape::kColumn + ); + + thread_offset_ += coord_offset; + + pointer_ += layout_({ + coord_offset.row(), + coord_offset.column() + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & operator+=(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); + + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + // de-vectorized stores + using ScalarAccessType = 
AlignedArray; + ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); + ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::kElementsPerAccess; s++) { + + int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; + int frag_idx = n * Policy::kElementsPerAccess + s; + + int col = thread_offset_.column() + ptr_idx; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + scalarPointer[ptr_idx] = scalarFragPtr[frag_idx]; + } + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + // de-vectorized loads + using ScalarAccessType = AlignedArray; + ScalarAccessType *scalarFragPtr = reinterpret_cast(&frag); + ScalarAccessType const *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::kElementsPerAccess; s++) { + + int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; + int frag_idx = n * Policy::kElementsPerAccess + s; + + int col = thread_offset_.column() + ptr_idx; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + scalarFragPtr[frag_idx] = scalarPointer[ptr_idx]; + } + } + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & operator++() { + return add_tile_offset({1, 0}); + } + +}; + + } // namespace warp } // namespace epilogue } // namespace cutlass diff --git 
a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index cec0b8f2..53dd8f4a 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -37,6 +37,11 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +// This is an optimization available on CUDA 11.2 and beyond that eliminates branches in the epilogue. +#define CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED ((__CUDACC_VER_MAJOR__ * 10 + __CUDACC_VER_MINOR__) >= 112) + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass { namespace epilogue { namespace warp { @@ -207,13 +212,34 @@ public: AccessType const *frag_ptr = reinterpret_cast(&frag); + AccessType *ptr = pointers_[0]; + +#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED + + // When the optimization is enabled, small tiles require separate logic. + if (WarpShape::kN == 32 && warp_column_ > 0) { + ptr = pointers_[1]; + } + +#endif + CUTLASS_PRAGMA_UNROLL for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { + +#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED + // + // When the optimization is enabled, this expression suffices to obtain the SMEM pointer. 
+ // + if (WarpShape::kN == 64) { + ptr = pointers_[n / 4]; + } + +#else + // This is the reference implementation int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; - AccessType *ptr; if (ptr_idx == 0) { ptr = pointers_[0 % Detail::kPointerCount]; } @@ -226,6 +252,8 @@ public: else if (ptr_idx == 3) { ptr = pointers_[3 % Detail::kPointerCount]; } +#endif + int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess; #if 0 @@ -673,3 +701,7 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + +#undef CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index c54bdac5..9fdcf126 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -36,6 +36,7 @@ #include "cutlass/cutlass.h" #include "cutlass/uint128.h" #include "cutlass/coord.h" +#include "cutlass/numeric_types.h" /** * \file @@ -50,6 +51,20 @@ namespace cutlass { * Static math utilities ******************************************************************************/ +/// Mixed precision dot product +template +CUTLASS_HOST_DEVICE LongIndex dot( + Coord const &coord, + Coord const &stride, + LongIndex acc = LongIndex()) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < N; ++n) { + acc += LongIndex(coord[n]) * stride[n]; + } + return acc; +} + /** * Statically determine if N is a power-of-two */ @@ -270,12 +285,33 @@ struct FastDivmod { fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. 
+ /// + /// Simply returns the quotient + CUTLASS_HOST_DEVICE + int divmod(int &remainder, int dividend) const { + int quotient; + fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + return quotient; + } + /// Computes integer division and modulus using precomputed values. This is computationally /// inexpensive. CUTLASS_HOST_DEVICE void operator()(int "ient, int64_t &remainder, int64_t dividend) const { fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + int divmod(int64_t &remainder, int64_t dividend) const { + int quotient; + fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + return quotient; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -387,7 +423,7 @@ struct FastDivmodU64 { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Computes the coordinate decomposition from a linear index. +/// Computes the coordinate decomposition from a linear index (64-bit linear index => coord) /// /// This decomposition is accelerated by the FastDivmodU64 object. It is assumed that /// a coordinate of indices can be decomposed by div/mod operations. 
@@ -428,6 +464,89 @@ CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( return coord; } +/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) +template +CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( + int linear_idx, ///< Linear index to decompose + FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank; i > 1; --i) { + int remainder; + linear_idx = divmod[i - 2].divmod(remainder, linear_idx); + coord[i - 1] = int(remainder); + } + + coord[0] = int(linear_idx); + + return coord; +} + +template +CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( + uint64_t linear_idx, ///< Linear index to decompose + FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank - 1; ++i) { + uint64_t remainder; + linear_idx = divmod[i].divmod(remainder, linear_idx); + coord[i] = int(remainder); + } + + coord[Rank - 1] = int(linear_idx); + + return coord; +} + +/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) +template +CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( + int linear_idx, ///< Linear index to decompose + FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank - 1; ++i) { + int remainder; + linear_idx = divmod[i].divmod(remainder, linear_idx); + coord[i] = int(remainder); + } + + coord[Rank - 1] = int(linear_idx); + + return coord; +} + +/// Safely computes the offset of a linear index in bytes for all types +template 
+CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) { + + static_assert( + (sizeof_bits::value >= 8 && !(sizeof_bits::value % 8)) || + (sizeof_bits::value < 8 && !(8 % sizeof_bits::value)), + "Size of numeric type in bits must either be divisible by 8 bits, or 8 bits must be divisible by the size."); + + if (sizeof_bits::value >= 8) { + return index * (sizeof_bits::value / 8); + } + else { + int const kElementsPerByte = ((8 / sizeof_bits::value) + ((sizeof_bits::value >= 8) ? 1 : 0)); + return index / kElementsPerByte; + } +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // Min/Max ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -566,6 +685,24 @@ double fast_sqrt(double theta) { #endif } +CUTLASS_HOST_DEVICE +float fast_exp(float x) { + #if defined(__CUDA_ARCH__) + return ::exp(x); + #else + return std::exp(x); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_exp(double x) { + #if defined(__CUDA_ARCH__) + return ::exp(x); + #else + return std::exp(x); + #endif +} + CUTLASS_HOST_DEVICE float fast_log(float x) { #if defined(__CUDA_ARCH__) diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 52d4ca59..5969a236 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -33,6 +33,7 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/complex.h" +#include "cutlass/quaternion.h" #include "cutlass/array.h" #include "cutlass/half.h" @@ -67,6 +68,15 @@ struct multiplies { } }; +template +struct multiplies> { + CUTLASS_HOST_DEVICE + Quaternion operator()(Quaternion lhs, Quaternion const &rhs) const { + lhs = lhs * rhs; + return lhs; + } +}; + /// Squares with optional conversion template struct square { @@ -105,6 +115,23 @@ struct magnitude_squared, Output> { } }; +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output 
operator()(Quaternion lhs) const { + multiplies mul_op; + + Output y_w = Output(lhs.w()); + Output y_x = Output(lhs.x()); + Output y_y = Output(lhs.y()); + Output y_z = Output(lhs.z()); + + return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + \ + mul_op(y_z, y_z); + } +}; + /// Computes the square of a difference with optional conversion template struct square_difference { @@ -1797,6 +1824,52 @@ Array fma(Array const &a, Array const &b, T c) { return op(a, b, c); } + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Quaternion fused multiply-add +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct multiply_add, Quaternion, Quaternion> { + CUTLASS_HOST_DEVICE + Quaternion operator()( + Quaternion const &a, + Quaternion const &b, + Quaternion const &c) const { + + T x = c.x(); + T y = c.y(); + T z = c.z(); + T w = c.w(); + + x += a.w() * b.x(); + x += b.w() * a.x(); + x += a.y() * b.z(); + x += -a.z() * b.y(), + + y += a.w() * b.y(); + y += b.w() * a.y(); + y += a.z() * b.x(); + y += -a.x() * b.z(); + + z += a.w() * b.z(); + z += b.w() * a.z(); + z += a.x() * b.y(); + z += -a.y() * b.x(); + + w += a.w() * b.w(); + w += -a.x() * b.x(); + w += -a.y() * b.y(); + w += -a.z() * b.z(); + + return cutlass::make_Quaternion(x, y, z, w); + + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index b398688f..ae364c40 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -446,6 +446,7 @@ public: cudaError_t result; int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -482,7 +483,7 @@ 
public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace); if (status == Status::kSuccess) { status = run(stream); @@ -673,7 +674,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); } /// Lightweight update given a subset of arguments diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index be7be25d..284a9771 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -473,7 +473,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace); if (status == Status::kSuccess) { status = run(stream); @@ -700,7 +700,7 @@ public: /// Initializes GEMM state from arguments. 
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); } /// Lightweight update given a subset of arguments diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index e1093270..25bba0d0 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -451,7 +451,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace); if (status == Status::kSuccess) { status = run(stream); @@ -666,7 +666,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); } /// Lightweight update given a subset of arguments diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 4b0fcaa9..68e83508 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -465,7 +465,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace); if (status == Status::kSuccess) { status = run(stream); @@ -674,7 +674,7 @@ public: /// Initializes GEMM state from arguments. 
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); } /// Lightweight update given a subset of arguments diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index b37f585a..99488c49 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -236,6 +236,7 @@ class SparseGemm { using EpilogueOutputOp = EpilogueOutputOp_; using ThreadblockSwizzle = ThreadblockSwizzle_; using Operator = Operator_; + using MathOperator = Operator; static int const kStages = Stages; static int const kAlignmentA = AlignmentA; static int const kAlignmentB = AlignmentB; diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index 987319c2..c6f40d77 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -621,7 +621,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index fb541701..78d7fc90 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -121,7 +121,7 @@ public: // warp-level, arch-level (instruction), math operator using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename ArchMmaOperator::Operator; + using MathOperator = typename WarpMmaOperator::MathOperator; // 
Operator class and arch tag extract bottom-up // set it for top-level gemm device-level template @@ -161,13 +161,11 @@ public: using TensorRefC = TensorRef; using TensorRefD = TensorRef; - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; - static int const kStages = GemmKernel::Mma::kStages; using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; using UnderlyingOperator = GemmUniversalBase; using Arguments = typename UnderlyingOperator::Arguments; diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 74c519a4..cb7a899d 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -171,9 +171,11 @@ public: // GEMM K dimension is greater than one. workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); } - + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + return workspace_bytes; } diff --git a/include/cutlass/gemm/device/gemv.h b/include/cutlass/gemm/device/gemv.h new file mode 100644 index 00000000..9c411f8e --- /dev/null +++ b/include/cutlass/gemm/device/gemv.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename GemvKernel_> +class Gemv { +public: + + using GemvKernel = GemvKernel_; + + + using ElementA = typename GemvKernel::ElementA; + using LayoutA = typename GemvKernel::LayoutA; + using ElementB = typename GemvKernel::ElementB; + using ElementC = typename GemvKernel::ElementC; + + using ElementAccumulator = typename GemvKernel::ElementAccumulator; + using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; + + static ComplexTransform const kTransformA = GemvKernel::kTransformA; + static ComplexTransform const kTransformB = GemvKernel::kTransformB; + + static int const kThreadCount = GemvKernel::kThreadCount; + static int const kStages = GemvKernel::kStages; + + static int const kAlignmentA = GemvKernel::kAlignmentA; + static int const kAlignmentB = GemvKernel::kAlignmentB; + static int const kAlignmentC = GemvKernel::kAlignmentC; + + using Arguments = typename GemvKernel::Arguments; + using Params = typename GemvKernel::Params; + +private: + + Params params_; + +public: + + /// Constructs the Gemv. + Gemv() { } + + /// Determines whether the Gemv can execute the given problem.
+ static Status can_implement(Arguments const &args) { + + return GemvKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return 0; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return dim3((args.problem_size.row() + (kThreadCount - 1)) / kThreadCount, 1, args.batch_count % 65535); + } + + /// Initializes Gemv state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + params_ = Params(args); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + return params_.update(args); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + dim3 grid = get_grid_shape(params_); + dim3 block(GemvKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); + + // Launch + cutlass::Kernel<GemvKernel><<<grid, block, smem_size, stream>>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state.
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm.h b/include/cutlass/gemm/kernel/default_gemm.h index 966b0089..81d05090 100644 --- a/include/cutlass/gemm/kernel/default_gemm.h +++ b/include/cutlass/gemm/kernel/default_gemm.h @@ -111,7 +111,9 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill = false> struct DefaultGemm; //////////////////////////////////////////////////////////////////////////////// @@ -133,6 +135,8 @@ template < int kAlignmentB, /// Element type for C and D matrix operands typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, /// Element type for internal accumulation typename ElementAccumulator, /// Threadblock-level tile size (concept: GemmShape) @@ -151,30 +155,47 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill> struct DefaultGemm { + Operator, UseZfill> { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, - ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + 
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator>::ThreadblockMma; + Operator, false, UseZfill>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using Epilogue = + using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue; + using Affine2Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename cutlass::platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + /// Define the kernel-level GEMM operator. using GemmKernel = kernel::Gemm; }; + //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Turing Architecture @@ -208,7 +229,9 @@ template < /// If true, kernel is configured to support serial reduction in the epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill > struct DefaultGemm< ElementA, LayoutA, kAlignmentA, @@ -224,7 +247,8 @@ struct DefaultGemm< ThreadblockSwizzle, 2, SplitKSerial, - Operator + Operator, + UseZfill > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -293,14 +317,16 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill> struct DefaultGemm< ElementA, layout::ColumnMajorInterleaved, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, int32_t, arch::OpClassTensorOp, arch::Sm80, 
ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, - SplitKSerial, Operator> { + SplitKSerial, Operator, UseZfill> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -312,7 +338,7 @@ struct DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, - true>::ThreadblockMma; + true, UseZfill>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; @@ -356,14 +382,16 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill> struct DefaultGemm, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - ThreadblockSwizzle, 2, SplitKSerial, Operator> { + ThreadblockSwizzle, 2, SplitKSerial, Operator, UseZfill> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -390,7 +418,6 @@ struct DefaultGemm, //////////////////////////////////////////////////////////////////////////////// - /// Partial specialization for Volta architecture template < /// Element type for A matrix operand @@ -420,7 +447,9 @@ template < /// If true, kernel is configured to support serial reduction in the epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill > struct DefaultGemm< ElementA, LayoutA, kAlignmentA, @@ -436,7 +465,8 @@ struct DefaultGemm< ThreadblockSwizzle, 2, SplitKSerial, - 
Operator + Operator, + UseZfill > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -491,6 +521,8 @@ template < int kAlignmentB, /// Element type for C and D matrix operands typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for @@ -506,7 +538,9 @@ template < /// If true, kernel is configured to support serial reduction in the epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill > struct DefaultGemm< ElementA, @@ -516,7 +550,7 @@ struct DefaultGemm< LayoutB, kAlignmentB, ElementC, - layout::RowMajor, + LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, @@ -527,7 +561,13 @@ struct DefaultGemm< ThreadblockSwizzle, 2, SplitKSerial, - Operator> { + Operator, + UseZfill> { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultMma< ElementA, @@ -537,7 +577,7 @@ struct DefaultGemm< LayoutB, kAlignmentB, ElementAccumulator, - layout::RowMajor, + LayoutC, arch::OpClassSimt, arch::Sm50, ThreadblockShape, @@ -550,13 +590,25 @@ struct DefaultGemm< static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); /// Define the epilogue - using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue; + using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< + 2, + ThreadblockShape, + typename Mma::Operator, + EpilogueOutputOp, + kEpilogueElementsPerAccess 
+ >::Epilogue; + + using Epilogue = typename cutlass::platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + /// Define the kernel-level GEMM operator. using GemmKernel = kernel::Gemm; }; @@ -579,6 +631,8 @@ template < int kAlignmentB, /// Element type for C and D matrix operands typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, /// Element type for internal accumulation typename ElementAccumulator, /// Threadblock-level tile size (concept: GemmShape) @@ -594,7 +648,10 @@ template < /// If true, kernel is configured to support serial reduction in the epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill +> struct DefaultGemm { + Operator, + UseZfill> { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, - ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80, + ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm80, ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, - Operator>::ThreadblockMma; + Operator, UseZfill>::ThreadblockMma; static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); /// Define the epilogue - using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue; + using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< + 2, + ThreadblockShape, + typename 
Mma::Operator, + EpilogueOutputOp, + kEpilogueElementsPerAccess + >::Epilogue; + + using Epilogue = typename cutlass::platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + /// Define the kernel-level GEMM operator. - using GemmKernel = kernel::Gemm; + using GemmKernel = kernel::Gemm; }; //////////////////////////////////////////////////////////////////////////////// @@ -669,12 +743,15 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill +> struct DefaultGemm, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, - Operator> { + Operator, UseZfill> { using InstructionShape = GemmShape<1, 1, 4>; using ElementA = int8_t; using ElementB = int8_t; @@ -753,7 +830,10 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill +> struct DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, @@ -766,7 +846,8 @@ struct DefaultGemm< ThreadblockSwizzle, Stages, SplitKSerial, - Operator> { + Operator, + UseZfill> { /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, @@ -795,6 +876,7 @@ struct DefaultGemm< using GemmKernel = kernel::Gemm; }; //////////////////////////////////////////////////////////////////////////////// + #endif //CUTLASS_ARCH_WMMA_ENABLED //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h index f9094672..f6fd6e65 100644 --- a/include/cutlass/gemm/kernel/default_gemm_universal.h +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -95,6 +95,8 @@ template < int Stages, /// Operation 
performed by GEMM typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill = false, /// typename Enable = void > @@ -141,7 +143,10 @@ template < /// Number of stages used in the pipelined mainloop int Stages, /// Operation performed by GEMM - typename Operator> + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill +> struct DefaultGemmUniversal< ElementA, LayoutA, @@ -163,6 +168,7 @@ struct DefaultGemmUniversal< ThreadblockSwizzle, Stages, Operator, + UseZfill, typename std::enable_if< ! cutlass::is_complex::value>::type > { @@ -185,13 +191,14 @@ struct DefaultGemmUniversal< ThreadblockSwizzle, Stages, true, - Operator + Operator, + UseZfill >::GemmKernel; /// Define the kernel in terms of the default kernel using GemmKernel = kernel::GemmUniversal< typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, + typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle >; }; @@ -242,7 +249,9 @@ template < /// Number of stages used in the pipelined mainloop int Stages, /// Operation performed by GEMM - typename Operator + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill > struct DefaultGemmUniversal< ElementA, @@ -265,6 +274,7 @@ struct DefaultGemmUniversal< ThreadblockSwizzle, Stages, Operator, + UseZfill, typename std::enable_if::value>::type > { diff --git a/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h b/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h new file mode 100644 index 00000000..8e6a8fc0 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Reduction based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation 
performed by GEMM + typename Operator, + /// + typename Enable = void +> +struct DefaultGemmWithBroadcast { + + using GemmBase = typename DefaultGemmUniversal< + ElementA_, LayoutA_, TransformA, kAlignmentA, + ElementB_, LayoutB_, TransformB, kAlignmentB, + ElementC_, LayoutC_, ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator + >::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + GemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue< + typename GemmBase::Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parital specialization: ArchTag = cutlass::arch::Sm70 +/// +/// +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename 
OperatorClass, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// + typename Enable +> +struct DefaultGemmWithBroadcast< + ElementA_, LayoutA_, TransformA, kAlignmentA, + ElementB_, LayoutB_, TransformB, kAlignmentB, + ElementC_, LayoutC_, + ElementAccumulator, + OperatorClass, + cutlass::arch::Sm70, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator, + Enable + > { + + using GemmBase = typename DefaultGemmUniversal< + ElementA_, LayoutA_, TransformA, kAlignmentA, + ElementB_, LayoutB_, TransformB, kAlignmentB, + ElementC_, LayoutC_, ElementAccumulator, + OperatorClass, + cutlass::arch::Sm70, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator + >::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + GemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue< + typename GemmBase::Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h new file mode 100644 index 00000000..ee32b7bb --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" +#include "cutlass/gemm/threadblock/default_mma_with_reduction.h" +#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// + bool ReduceKForA_, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + 
/// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill = false, + /// + typename Enable = void> +struct DefaultGemmWithKReduction { + + static const bool kReduceKForA = (platform::is_same::value) ? ReduceKForA_ : !ReduceKForA_; + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMmaWithReduction< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, UseZfill>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the epilogue + using EpilogueGemmKReduction = + typename cutlass::epilogue::threadblock::EpilogueGemmKReduction< + ElementAccumulator, ElementC, ThreadblockShape, typename Mma::Operator, kReduceKForA>; + + /// Define the kernel-level GEMM operator. 
+ using GemmKernel = kernel::GemmWithKReduction; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_with_reduction.h b/include/cutlass/gemm/kernel/default_gemm_with_reduction.h index 47c075c9..0aab9c87 100644 --- a/include/cutlass/gemm/kernel/default_gemm_with_reduction.h +++ b/include/cutlass/gemm/kernel/default_gemm_with_reduction.h @@ -107,7 +107,8 @@ struct DefaultGemmWithReduction { EpilogueOutputOp, ThreadblockSwizzle, Stages, - Operator + Operator, + true >::GemmKernel; // Replace epilogue @@ -129,7 +130,6 @@ struct DefaultGemmWithReduction { >; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Parital specialization: ArchTag = cutlass::arch::Sm70 diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index 1d5601cd..0acc9084 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -65,6 +65,7 @@ struct Gemm { struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorA::TensorRef ref_A; typename Mma::IteratorB::Params params_B; @@ -83,7 +84,7 @@ struct Gemm { // CUTLASS_HOST_DEVICE - Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } CUTLASS_HOST_DEVICE Params( @@ -98,6 +99,7 @@ struct Gemm { ): problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(ref_A.layout()), ref_A(ref_A), params_B(ref_B.layout()), @@ -188,7 +190,7 @@ struct Gemm { 
ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -266,7 +268,7 @@ struct Gemm { // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h index 0df21742..83fd071e 100644 --- a/include/cutlass/gemm/kernel/gemm_array.h +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -61,6 +61,7 @@ struct GemmArray { struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorA::Element const * const * ptr_A; typename Mma::IteratorB::Params params_B; @@ -79,7 +80,8 @@ struct GemmArray { // CUTLASS_HOST_DEVICE - Params() { } + Params() : + swizzle_log_tile(0) { } CUTLASS_HOST_DEVICE Params( @@ -98,6 +100,7 @@ struct GemmArray { ): problem_size(problem_size_), grid_tiled_shape(grid_tiled_shape_), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(layout_A), ptr_A(ptr_A_), params_B(layout_B), @@ -134,7 +137,7 @@ struct GemmArray { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -209,7 +212,7 @@ struct GemmArray { // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + 
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h index ceefed12..36dd2501 100644 --- a/include/cutlass/gemm/kernel/gemm_batched.h +++ b/include/cutlass/gemm/kernel/gemm_batched.h @@ -61,6 +61,7 @@ struct GemmBatched { struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorA::TensorRef ref_A; int64_t stride_A; @@ -82,7 +83,7 @@ struct GemmBatched { // CUTLASS_HOST_DEVICE - Params() { } + Params() : swizzle_log_tile(0) { } CUTLASS_HOST_DEVICE Params( @@ -101,6 +102,7 @@ struct GemmBatched { ): problem_size(problem_size_), grid_tiled_shape(grid_tiled_shape_), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(ref_A_.layout()), ref_A(ref_A_), stride_A(stride_A_), @@ -141,7 +143,7 @@ struct GemmBatched { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -221,7 +223,7 @@ struct GemmBatched { // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_params.h b/include/cutlass/gemm/kernel/gemm_params.h new file mode 100755 index 00000000..5c8ea7bf --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_params.h @@ -0,0 +1,193 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. 
All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct GemmParams { + + // + // Type definitions + // + using Index = int32_t; + using LongIndex = int64_t; + + using MmaIteratorParams = typename cutlass::transform::threadblock::PredicatedTileAccessIteratorParams; + using EpilogueIteratorParams = typename cutlass::epilogue::threadblock::PredicatedTileIteratorParams; + + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + // Data members for Mma::Iterator::Params + MmaIteratorParams params_itr_a; + MmaIteratorParams params_itr_b; + + // Data member for Epilogue::OutputTileIterator::Params + EpilogueIteratorParams params_itr_c; + EpilogueIteratorParams params_itr_d; + + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + LongIndex lda; + LongIndex ldb; + LongIndex ldc; + LongIndex ldd; + + LongIndex batch_stride_A; + LongIndex batch_stride_B; + LongIndex batch_stride_C; + LongIndex batch_stride_D; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmParams() {} + + CUTLASS_HOST_DEVICE + GemmParams( + cutlass::gemm::GemmCoord problem_size_, + cutlass::gemm::GemmCoord grid_tiled_shape_, + 
int swizzle_log_tile_, + GemmUniversalMode mode_, + int batch_count_, + int gemm_k_size_, + void const * ptr_A_, + void const * ptr_B_, + void const * ptr_C_, + void * ptr_D_, + LongIndex lda_, + LongIndex ldb_, + LongIndex ldc_, + LongIndex ldd_, + int64_t batch_stride_A_, + int64_t batch_stride_B_, + int64_t batch_stride_C_, + int64_t batch_stride_D_, + MmaIteratorParams const & params_itr_a_, + MmaIteratorParams const & params_itr_b_, + EpilogueIteratorParams const & params_itr_c_, + EpilogueIteratorParams const & params_itr_d_, + void *workspace_ = nullptr) : + problem_size(problem_size_), + grid_tiled_shape(grid_tiled_shape_), + swizzle_log_tile(swizzle_log_tile_), + mode(mode_), + batch_count(batch_count_), + gemm_k_size(gemm_k_size_), + ptr_A(const_cast(ptr_A_)), + ptr_B(const_cast(ptr_B_)), + ptr_C(const_cast(ptr_C_)), + ptr_D(ptr_D_), + lda(lda_), + ldb(ldb_), + ldc(ldc_), + ldd(ldd_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_), + params_itr_a(params_itr_a_), + params_itr_b(params_itr_b_), + params_itr_c(params_itr_c_), + params_itr_d(params_itr_d_), + semaphore(static_cast(workspace_) + ) { } + + + CUTLASS_HOST_DEVICE + void update( + void const * ptr_A_, + void const * ptr_B_, + void const * ptr_C_, + void * ptr_D_, + int64_t batch_stride_A_, + int64_t batch_stride_B_, + int64_t batch_stride_C_, + int64_t batch_stride_D_, + void *workspace_ = nullptr) { + + ptr_A = const_cast(ptr_A_); + ptr_B = const_cast(ptr_B_); + ptr_C = const_cast(ptr_C_); + ptr_D = ptr_D_; + + batch_stride_A = batch_stride_A_; + batch_stride_B = batch_stride_B_; + batch_stride_C = batch_stride_C_; + batch_stride_D = batch_stride_D_; + + + semaphore = static_cast(workspace_); + CUTLASS_TRACE_HOST("GemmParams::update()"); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace 
cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_pipelined.h b/include/cutlass/gemm/kernel/gemm_pipelined.h index 39f328a3..587186e6 100644 --- a/include/cutlass/gemm/kernel/gemm_pipelined.h +++ b/include/cutlass/gemm/kernel/gemm_pipelined.h @@ -66,7 +66,9 @@ __global__ void GemmPipelined( // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape); + int swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(swizzle_log_tile); if (grid_tiled_shape.m() <= tb_tile_offset.m() || grid_tiled_shape.n() <= tb_tile_offset.n()) { @@ -131,7 +133,7 @@ __global__ void GemmPipelined( warp_id, lane_id); - tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape); + tb_tile_offset = threadblock_swizzle.get_tile_offset(swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index 0151848f..f03f7a53 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -123,14 +123,14 @@ public: void * ptr_D_real; void * ptr_D_imag; - int lda_real; - int lda_imag; - int ldb_real; - int ldb_imag; - int ldc_real; - int ldc_imag; - int ldd_real; - int ldd_imag; + typename LayoutA::Stride::Index lda_real; + typename LayoutA::Stride::Index lda_imag; + typename LayoutB::Stride::Index ldb_real; + typename LayoutB::Stride::Index ldb_imag; + typename LayoutC::Stride::Index ldc_real; + typename LayoutC::Stride::Index ldc_imag; + typename LayoutC::Stride::Index ldd_real; + typename LayoutC::Stride::Index ldd_imag; int64_t batch_stride_A; int64_t batch_stride_A_imag; @@ -173,14 +173,14 
@@ public: void const * ptr_C_imag, void * ptr_D_real, void * ptr_D_imag, - int lda_real, - int lda_imag, - int ldb_real, - int ldb_imag, - int ldc_real, - int ldc_imag, - int ldd_real, - int ldd_imag, + typename LayoutA::Stride::Index lda_real, + typename LayoutA::Stride::Index lda_imag, + typename LayoutB::Stride::Index ldb_real, + typename LayoutB::Stride::Index ldb_imag, + typename LayoutC::Stride::Index ldc_real, + typename LayoutC::Stride::Index ldc_imag, + typename LayoutC::Stride::Index ldd_real, + typename LayoutC::Stride::Index ldd_imag, int64_t batch_stride_A = 0, int64_t batch_stride_A_imag = 0, int64_t batch_stride_B = 0, @@ -245,6 +245,7 @@ public: struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A_real; typename Mma::IteratorA::Params params_A_imag; @@ -289,6 +290,7 @@ public: Params(): batch_count(0), gemm_k_size(0), + swizzle_log_tile(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A_real(nullptr), ptr_A_imag(nullptr), @@ -317,6 +319,7 @@ public: ): problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A_real(args.lda_real), params_A_imag(args.lda_imag), params_B_real(args.ldb_real), @@ -412,6 +415,12 @@ public: return Status::kSuccess; } + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { @@ -420,7 +429,7 @@ public: ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -551,7 +560,7 @@ 
public: // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 05bde223..52ee205a 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -127,14 +127,14 @@ public: void * const * ptr_D_real; void * const * ptr_D_imag; - int lda_real; - int lda_imag; - int ldb_real; - int ldb_imag; - int ldc_real; - int ldc_imag; - int ldd_real; - int ldd_imag; + typename LayoutA::Stride::Index lda_real; + typename LayoutA::Stride::Index lda_imag; + typename LayoutB::Stride::Index ldb_real; + typename LayoutB::Stride::Index ldb_imag; + typename LayoutC::Stride::Index ldc_real; + typename LayoutC::Stride::Index ldc_imag; + typename LayoutC::Stride::Index ldd_real; + typename LayoutC::Stride::Index ldd_imag; int64_t batch_stride_D; // unused @@ -175,14 +175,14 @@ public: void const * const * ptr_C_imag, void * const * ptr_D_real, void * const * ptr_D_imag, - int lda_real, - int lda_imag, - int ldb_real, - int ldb_imag, - int ldc_real, - int ldc_imag, - int ldd_real, - int ldd_imag + typename LayoutA::Stride::Index lda_real, + typename LayoutA::Stride::Index lda_imag, + typename LayoutB::Stride::Index ldb_real, + typename LayoutB::Stride::Index ldb_imag, + typename LayoutC::Stride::Index ldc_real, + typename LayoutC::Stride::Index ldc_imag, + typename LayoutC::Stride::Index ldd_real, + typename LayoutC::Stride::Index ldd_imag ): mode(GemmUniversalMode::kArray), problem_size(problem_size), @@ -234,7 +234,7 @@ public: struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; - + int swizzle_log_tile; typename Mma::IteratorA::Params params_A_real; typename Mma::IteratorA::Params 
params_A_imag; typename Mma::IteratorB::Params params_B_real; @@ -268,6 +268,7 @@ public: CUTLASS_HOST_DEVICE Params(): batch_count(0), + swizzle_log_tile(0), ptr_M(nullptr), ptr_N(nullptr), ptr_K(nullptr), @@ -289,6 +290,7 @@ public: ): problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), ptr_M(args.ptr_M), ptr_N(args.ptr_N), ptr_K(args.ptr_K), @@ -369,6 +371,12 @@ public: return Status::kSuccess; } + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { @@ -377,7 +385,7 @@ public: ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || diff --git a/include/cutlass/gemm/kernel/gemm_splitk_parallel.h b/include/cutlass/gemm/kernel/gemm_splitk_parallel.h index e009567e..11ab74fd 100644 --- a/include/cutlass/gemm/kernel/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/kernel/gemm_splitk_parallel.h @@ -63,6 +63,7 @@ struct GemmSplitKParallel { struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorA::TensorRef ref_A; typename Mma::IteratorB::Params params_B; @@ -78,7 +79,7 @@ struct GemmSplitKParallel { // CUTLASS_HOST_DEVICE - Params() { } + Params(): swizzle_log_tile(0) { } CUTLASS_HOST_DEVICE Params( @@ -92,6 +93,7 @@ struct GemmSplitKParallel { ): problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), 
params_A(ref_A.layout()), ref_A(ref_A), params_B(ref_B.layout()), @@ -129,7 +131,7 @@ struct GemmSplitKParallel { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -207,7 +209,7 @@ struct GemmSplitKParallel { // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( @@ -243,4 +245,3 @@ struct GemmSplitKParallel { } // namespace kernel } // namespace gemm } // namespace cutlass - diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index 0ff5ce99..0b590f1c 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -115,10 +115,15 @@ public: int64_t batch_stride_C; int64_t batch_stride_D; - int lda; - int ldb; - int ldc; - int ldd; + typename LayoutA::Stride stride_a; + typename LayoutB::Stride stride_b; + typename LayoutC::Stride stride_c; + typename LayoutC::Stride stride_d; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; // // Methods @@ -143,10 +148,10 @@ public: int64_t batch_stride_B, int64_t batch_stride_C, int64_t batch_stride_D, - int lda, - int ldb, - int ldc, - int ldd + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d ): mode(mode), problem_size(problem_size), @@ -154,11 +159,44 @@ public: epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), 
batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d) { CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); } + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + stride_a = make_Coord(lda); + stride_b = make_Coord(ldb); + stride_c = make_Coord(ldc); + stride_d = make_Coord(ldd); + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + /// Returns arguments for the transposed problem Arguments transposed_problem() const { Arguments args(*this); @@ -166,6 +204,7 @@ public: std::swap(args.problem_size.m(), args.problem_size.n()); std::swap(args.ptr_A, args.ptr_B); std::swap(args.lda, args.ldb); + std::swap(args.stride_a, args.stride_b); std::swap(args.batch_stride_A, args.batch_stride_B); return args; @@ -181,6 +220,7 @@ public: cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; @@ -211,6 
+251,7 @@ public: CUTLASS_HOST_DEVICE Params(): + swizzle_log_tile(0), params_A(0), params_B(0), params_C(0), @@ -237,10 +278,11 @@ public: ): problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape), - params_A(args.lda), - params_B(args.ldb), - params_C(args.ldc), - params_D(args.ldd), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), + params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), + params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), + params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), output_op(args.epilogue), mode(args.mode), batch_count(args.batch_count), @@ -276,7 +318,6 @@ public: output_op = args.epilogue; semaphore = static_cast(workspace); - CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); } }; @@ -335,6 +376,12 @@ public: return can_implement(args.problem_size); } + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { @@ -343,7 +390,7 @@ public: ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -393,7 +440,6 @@ public: threadblock_tile_offset.n() * Mma::Shape::kN }; - // Compute position within threadblock int thread_idx = threadIdx.x; @@ -450,8 +496,7 @@ public: // Masked tile iterators constructed from members // - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle 
MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h new file mode 100644 index 00000000..c483beaa --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -0,0 +1,735 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Gemm kernel with fused reduction operation. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmWithFusedEpilogue { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const 
kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + void * ptr_Vector; + void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldd; + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void * ptr_Vector, + void * ptr_Tensor, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_Vector, + int64_t batch_stride_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldd, + typename 
LayoutC::Stride::Index ldr, + typename LayoutC::Stride::Index ldt + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) + { + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << this->ldt); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + + 
void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Vector(nullptr), + ldr(0), + ptr_Tensor(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + batch_stride_Vector(0), + batch_stride_Tensor(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + params_Tensor(args.ldt), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor), + + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor), + + semaphore(static_cast(workspace)) { + + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + 
CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << args.ldt); + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + ptr_Vector = args.ptr_Vector; + ldr = args.ldr; + ptr_Tensor = args.ptr_Tensor; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + batch_stride_Vector = args.batch_stride_Vector; + batch_stride_Tensor = args.batch_stride_Tensor; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithFusedEpilogue() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + return Status::kErrorMisalignedOperand; + } + 
+ CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + #define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + + #if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. 
+ // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + #endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // + // Fetch pointers based on mode. + // + + // + // Special path when split-K not enabled. + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + params.problem_size.mn(), + threadblock_offset); + + return; + } + + // + // Slower path when split-K or batching is needed + // + + + #if SPLIT_K_ENABLED + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + if (ptr_Tensor) { + ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; + } + if (ptr_Vector) { + ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; + } + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; + ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; + } + #endif + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? 
nullptr + : ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + #if SPLIT_K_ENABLED + // Wait on the semaphore - this latency may have been covered by iterator construction + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + #endif + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + params.problem_size.mn(), + threadblock_offset); + + // + // Release the semaphore + // + + #if SPLIT_K_ENABLED + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h new file mode 100644 index 00000000..b24631dd --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -0,0 +1,649 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename EpilogueGemmKReduction_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmWithKReduction { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using EpilogueGemmKReduction = EpilogueGemmKReduction_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using LayoutGemmKReduction = cutlass::layout::PitchLinear; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + static int const kReduceKForA = Mma::kReduceKForA; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; 
+ + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + void * ptr_gemm_k_reduction; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_gemm_k_reduction; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldd; + typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_gemm_k_reduction(nullptr) { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void * ptr_gemm_k_reduction, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_gemm_k_reduction, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldd, + typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_gemm_k_reduction(ptr_gemm_k_reduction), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), batch_stride_gemm_k_reduction(batch_stride_gemm_k_reduction), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ld_gemm_k_reduction(ld_gemm_k_reduction) { + + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// Returns arguments for the transposed problem + Arguments 
transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + void * ptr_gemm_k_reduction; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_gemm_k_reduction; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_gemm_k_reduction(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + batch_stride_gemm_k_reduction(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + 
output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_gemm_k_reduction(args.batch_stride_gemm_k_reduction), + semaphore(static_cast(workspace)) { + + CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << problem_size); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D = workspace; + ptr_gemm_k_reduction = static_cast(workspace) + + sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else { + ptr_D = args.ptr_D; + ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; + } + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + batch_stride_gemm_k_reduction = args.batch_stride_gemm_k_reduction; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithKReduction() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + + static int const kAlignmentA = 
(platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + size_t workspace_bytes = 0; + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = + sizeof(ElementC) * + size_t(args.batch_stride_gemm_k_reduction) * + size_t(grid_tiled_shape.k()); + } + + return workspace_bytes; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based 
on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + typename Mma::FragmentReduction gemm_k_accumulators; + + gemm_k_accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators, + gemm_k_accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + ElementC *ptr_gemm_k_reduction = static_cast(params.ptr_gemm_k_reduction); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_gemm_k_reduction += threadblock_tile_offset.k() * params.batch_stride_gemm_k_reduction; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
+ if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + if ((kReduceKForA && threadblock_tile_offset.n() == 0) + || (!kReduceKForA && threadblock_tile_offset.m() == 0)) { + + int warp_idx_mn = warp_idx % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Mma::Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Mma::Base::WarpCount::kM; + + if ((kReduceKForA && warp_idx_n == 0) + || (!kReduceKForA && warp_idx_m == 0)) { + + int reduction_warp_idx = kReduceKForA ? warp_idx_m : warp_idx_n; + int reduction_threadblock_offset = kReduceKForA ? threadblock_tile_offset.m() : + threadblock_tile_offset.n(); + int reduction_vector_size = kReduceKForA ? params.problem_size.m() + : params.problem_size.n(); + EpilogueGemmKReduction epilogue_gemm_k_reduction(thread_idx, + reduction_warp_idx, + lane_idx, + reduction_threadblock_offset, + ptr_gemm_k_reduction); + epilogue_gemm_k_reduction( + reduction_vector_size, + gemm_k_accumulators, + params.mode == GemmUniversalMode::kGemm + && (params.grid_tiled_shape.k() > 1) + && (threadblock_tile_offset.k() > 0)); + } + } + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemv.h b/include/cutlass/gemm/kernel/gemv.h new file mode 100644 index 00000000..29f49817 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemv.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename ElementC_, + typename ElementAccumulator_, + typename EpilogueOutputOp_ +> +struct Gemv { +public: + + using ElementA = ElementA_; + using LayoutA = layout::ColumnMajor; + using TensorRefA = TensorRef; + + static_assert(std::is_same::value, + "Only supported for column-major A matrix"); + + using ElementB = ElementB_; + using ElementC = ElementC_; + + using ElementAccumulator = ElementAccumulator_; + using EpilogueOutputOp = EpilogueOutputOp_; + + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + static int const kThreadCount = 32; + static int const kStages = 1; + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + static int 
const kAlignmentC = 1; + + // + // Structures + // + + /// Argument structure + struct Arguments { + MatrixCoord problem_size; + int32_t batch_count; + typename EpilogueOutputOp::Params output_op; + + TensorRefA ref_A; + + ElementB const *ptr_B; + ElementC const *ptr_C; + ElementC *ptr_D; + + int64_t inc_B; + int64_t inc_C; + int64_t inc_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + Arguments(): batch_count(0) { } + + Arguments( + MatrixCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params output_op, + TensorRefA ref_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t inc_B, + int64_t inc_C, + int64_t inc_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D + ): + problem_size(problem_size), + batch_count(batch_count), + output_op(output_op), + ref_A(ref_A), + ptr_B(static_cast(ptr_B)), + ptr_C(static_cast(ptr_C)), + ptr_D(static_cast(ptr_D)), + inc_B(inc_B), + inc_C(inc_C), + inc_D(inc_D), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) + { } + + Arguments( + MatrixCoord problem_size, + typename EpilogueOutputOp::Params output_op, + TensorRefA ref_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t inc_B, + int64_t inc_C, + int64_t inc_D + ): + Arguments( + problem_size, + 1, + output_op, + ref_A, + ptr_B, + ptr_C, + ptr_D, + inc_B, + inc_C, + inc_D, + 1, + 1, + 1, + 1) + { } + + Status update(Arguments const &args) { + output_op = args.output_op; + ref_A = ref_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + + return Status::kSuccess; + } + }; + + using Params = Arguments; + + /// Shared memory storage structure + union SharedStorage { + + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + Gemv() { } + + /// Determines whether kernel satisfies 
alignment + static Status can_implement(cutlass::MatrixCoord const & problem_size) { + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Loop over batch indices + for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z) { + + int i = blockIdx.x * kThreadCount + threadIdx.x; + + ElementA const *ptr_A = params.ref_A.data() + i; + ElementB const *ptr_B = params.ptr_B; + + ptr_A += batch_idx * params.batch_stride_A; + ptr_B += batch_idx * params.batch_stride_B; + + ElementAccumulator accum = ElementAccumulator(); + + // Compute inner product + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < params.problem_size.column(); ++k) { + + // Fetch from A + ElementA a = ElementA(); + if (i < params.problem_size.row()) { + a = *ptr_A; + } + ptr_A += params.ref_A.stride(0); + + // Fetch from B + ElementB b = *ptr_B; + ptr_B += params.inc_B; + + // Math + accum += ElementAccumulator(a) * ElementAccumulator(b); + } + + // + // Epilogue phase + // + + ElementC const *ptr_C = params.ptr_C + i * params.inc_C + batch_idx * params.batch_stride_C; + ElementC *ptr_D = params.ptr_D + i * params.inc_D + batch_idx * params.batch_stride_D; + + EpilogueOutputOp output_op(params.output_op); + + typename EpilogueOutputOp::FragmentAccumulator accum_fragment; + typename EpilogueOutputOp::FragmentOutput source_fragment; + typename EpilogueOutputOp::FragmentOutput output_fragment; + + accum_fragment[0] = accum; + + if (i < params.problem_size.row()) { + if (output_op.is_source_needed()) { + source_fragment[0] = *ptr_C; + output_fragment = output_op(accum_fragment, source_fragment); + } + else { + output_fragment = output_op(accum_fragment); + } + + *ptr_D = output_fragment[0]; + } + } + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h index 9d9e0a28..c86cda40 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -72,6 +72,7 @@ struct SparseGemm { struct Params { cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; typename Mma::IteratorA::Params params_A; typename Mma::IteratorA::TensorRef ref_A; typename Mma::IteratorB::Params params_B; @@ -92,7 +93,7 @@ struct SparseGemm { // CUTLASS_HOST_DEVICE - Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } CUTLASS_HOST_DEVICE Params( @@ -108,6 +109,7 @@ struct SparseGemm { ): problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(ref_A.layout()), ref_A(ref_A), params_B(ref_B.layout()), @@ -210,7 +212,7 @@ struct SparseGemm { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -299,7 +301,7 @@ struct SparseGemm { // threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index 
e7bbbc90..fb3bad29 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -125,7 +125,7 @@ struct MmaGeneric { reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); TensorRef d_ref( - reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); + reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); MmaOp mma_op; diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h index 839e07a7..eb1a3c33 100644 --- a/include/cutlass/gemm/thread/mma_sm60.h +++ b/include/cutlass/gemm/thread/mma_sm60.h @@ -79,13 +79,9 @@ struct Mma_HFMA2 < true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the M dimension to be divisible by 2." - ); - /// A operand storage using FragmentA = Array; @@ -98,6 +94,11 @@ struct Mma_HFMA2 < /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kM % 2), + "Mma_HFMA2 requires the M dimension to be divisible by 2." + ); + // // Methods // @@ -170,13 +171,9 @@ struct Mma_HFMA2< true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - /// A operand storage using FragmentA = Array; @@ -189,6 +186,11 @@ struct Mma_HFMA2< /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kN % 2), + "Mma_HFMA2 requires the N dimension to be divisible by 2." + ); + // // Methods // @@ -266,13 +268,9 @@ struct Mma_HFMA2 < true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." 
- ); - /// A operand storage using FragmentA = Array; @@ -285,6 +283,11 @@ struct Mma_HFMA2 < /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kM % 2), + "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." + ); + // // Methods // @@ -357,14 +360,10 @@ struct Mma_HFMA2< true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - - /// A operand storage + /// A operand storage using FragmentA = Array; /// B operand storage @@ -375,6 +374,12 @@ struct Mma_HFMA2< /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + + static_assert( + !(Shape::kN % 2), + "Mma_HFMA2 requires the N dimension to be divisible by 2." + ); + // // Methods // @@ -448,14 +453,10 @@ struct Mma_HFMA2 < true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the M dimension to be divisible by 2." - ); - - /// A operand storage + /// A operand storage using FragmentA = Array; /// B operand storage @@ -467,6 +468,11 @@ struct Mma_HFMA2 < /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kM % 2), + "Mma_HFMA2 requires the M dimension to be divisible by 2." + ); + // // Methods // @@ -543,13 +549,9 @@ struct Mma_HFMA2 < true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - /// A operand storage using FragmentA = Array; @@ -562,6 +564,11 @@ struct Mma_HFMA2 < /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kN % 2), + "Mma_HFMA2 requires the N dimension to be divisible by 2." 
+ ); + // // Methods // @@ -638,13 +645,9 @@ struct Mma_HFMA2 < true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kM % 2), - "Mma_HFMA2 requires the M dimension to be divisible by 2." - ); - /// A operand storage using FragmentA = Array; @@ -657,6 +660,11 @@ struct Mma_HFMA2 < /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kM % 2), + "Mma_HFMA2 requires the M dimension to be divisible by 2." + ); + // // Methods // @@ -734,14 +742,10 @@ struct Mma_HFMA2< true > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kN % 2), - "Mma_HFMA2 requires the N dimension to be divisible by 2." - ); - - /// A operand storage + /// A operand storage using FragmentA = Array; /// B operand storage @@ -753,6 +757,11 @@ struct Mma_HFMA2< /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kN % 2), + "Mma_HFMA2 requires the N dimension to be divisible by 2." + ); + // // Methods // @@ -825,14 +834,10 @@ struct Mma_HFMA2< false > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kK % 2), - "Mma_HFMA2 requires the K dimension to be divisible by 2." - ); - - /// A operand storage + /// A operand storage using FragmentA = Array; /// B operand storage @@ -844,6 +849,11 @@ struct Mma_HFMA2< /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; + static_assert( + !(Shape::kK % 2), + "Mma_HFMA2 requires the K dimension to be divisible by 2." + ); + // // Methods // @@ -909,14 +919,10 @@ struct Mma_HFMA2< false > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; - static_assert( - !(Shape::kK % 2), - "Mma_HFMA2 requires the K dimension to be divisible by 2." 
- ); - - /// A operand storage + /// A operand storage using FragmentA = Array; /// B operand storage @@ -927,7 +933,12 @@ struct Mma_HFMA2< /// Underlying mathematical operator using Operator = arch::OpMultiplyAdd; - + + static_assert( + !(Shape::kK % 2), + "Mma_HFMA2 requires the K dimension to be divisible by 2." + ); + // // Methods // diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 15550809..48eb2ebf 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -86,7 +86,9 @@ template < typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false + bool AccumulatorsInRowMajor = false, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill = false > struct DefaultMma; @@ -108,6 +110,8 @@ template < int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) @@ -119,13 +123,19 @@ template < /// Operation performed by GEMM typename Operator> struct DefaultMma { + + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt, 2, Operator>; // Define iterators over tiles from the A operand @@ -144,7 +154,7 @@ struct DefaultMma; + LayoutC, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// @@ -342,6 +352,8 @@ 
template < int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) @@ -356,13 +368,18 @@ template < typename Operator > struct DefaultMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, - ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt, Stages, Operator>; // Define iterators over tiles from the A operand @@ -385,7 +402,7 @@ struct DefaultMma; }; @@ -407,6 +424,8 @@ template < int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) @@ -418,12 +437,19 @@ template < /// Number of stages used in the multistage mainloop int Stages, /// Operation perfomed by GEMM - typename Operator + typename Operator, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill > struct DefaultMma { + InstructionShape, Stages, Operator, false, UseZfill> { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? 
cutlass::arch::CacheOperation::Global @@ -437,7 +463,7 @@ struct DefaultMma; // Define iterators over tiles from the A operand @@ -460,8 +486,8 @@ struct DefaultMma; + MmaCore::kCacheOpB, ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages, UseZfill>; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma_core.h b/include/cutlass/gemm/threadblock/default_mma_core.h index 5a5426f4..49a58da0 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core.h +++ b/include/cutlass/gemm/threadblock/default_mma_core.h @@ -87,9 +87,9 @@ template < cutlass::arch::OpMultiplyAdd>::type, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. - bool AccumulatorsInRowMajor = false + bool AccumulatorsInRowMajor = false, /// Cache operation of operand A - , cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Kind CacheOpA = cutlass::arch::CacheOperation::Global, /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB = diff --git a/include/cutlass/gemm/threadblock/default_mma_core_simt.h b/include/cutlass/gemm/threadblock/default_mma_core_simt.h index 2ec882cc..4364e33f 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_simt.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_simt.h @@ -364,6 +364,9 @@ struct DefaultMmaCore, ElementA_, static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + // these should have max of thread tile also using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, @@ -526,6 +529,9 @@ struct DefaultMmaCore, ElementA_, static int const kPaddingM = 
detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static_assert(!(kPaddingM % LaneM), + "Padding must be divisible by Lane"); + // these should have max of thread tile also using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, @@ -688,6 +694,9 @@ struct DefaultMmaCore, ElementA_, static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static_assert(!(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + // these should have max of thread tile also using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, @@ -721,6 +730,354 @@ struct DefaultMmaCore, ElementA_, ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization: +/// +/// A: column-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Operation performed by GEMM + typename Operator_> +struct DefaultMmaCore, ElementA_, + layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2RowMajor, + ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ + > { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::AffineRank2ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + /// Default Operator + using Operator = Operator_; + + using Base = 
DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Operation performed by GEMM + typename Operator_> +struct DefaultMmaCore, ElementA_, + layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2ColumnMajor, + ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ + > { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::AffineRank2RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = 
arch::OpClassSimt; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Operation performed by GEMM + typename Operator_> +struct DefaultMmaCore, ElementA_, + layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2RowMajor, ElementC_, + LayoutC_, arch::OpClassSimt, 2, Operator_ + > { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::AffineRank2RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2RowMajor; + 
using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major +/// B: column-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Operation performed by GEMM + typename Operator_> +struct DefaultMmaCore, ElementA_, + layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2ColumnMajor, + ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ + > { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = 
layout::AffineRank2ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Partial specialization: /// /// A: column-major diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h index 8b0c0de6..d0a11cef 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h @@ -532,6 +532,379 @@ struct DefaultMmaCore, WarpCount::kK>; }; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for double-precision +/// +/// A: column-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: 
GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::AffineRank2ColumnMajor; + using ElementB = double; + using LayoutB = layout::AffineRank2ColumnMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +/// Partial specialization for double-precision +/// +/// A: column-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level 
operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::AffineRank2ColumnMajor; + using ElementB = double; + using LayoutB = layout::AffineRank2RowMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = 
typename Base::MmaPolicy; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for double-precision +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::AffineRank2RowMajor; + using ElementB = double; + using LayoutB = layout::AffineRank2ColumnMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of 
iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// +/// Partial specialization for double-precision +/// +/// A: row-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = double; + using LayoutA = layout::AffineRank2RowMajor; + using ElementB = double; + using LayoutB = layout::AffineRank2RowMajor; + using ElementC = double; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename 
Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; //////////////////////////////////////////////////////////////////////////////// @@ -1639,6 +2012,10 @@ struct DefaultMmaCore::value; static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static_assert(!((Shape::kK / 32) % LaneN), + "Padding must be divisible by Lane"); + // these should have max of thread tile also using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, @@ -1947,6 +2324,10 @@ struct DefaultMmaCore::value; static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static_assert(!((Shape::kK / 32) % LaneM) && !((Shape::kK / 32) % LaneN), + "Padding must be divisible by Lane"); + // these should have max of thread tile also using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, @@ -2100,6 +2481,10 @@ struct DefaultMmaCore::value; static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static_assert(!((Shape::kK / 32) % LaneM), + "Padding must be divisible by Lane"); + // these should have max of thread tile also using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, @@ -2130,6 +2515,388 @@ struct DefaultMmaCore; }; +/// Partial specialization for SIMT GEMMs using multistage pipeline. 
+/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::AffineRank2ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; 
+ + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::AffineRank2RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to 
write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; +}; + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::AffineRank2ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static 
cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; + +}; + +/// Partial specialization for SIMT GEMMs using multistage pipeline. +/// +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by Simt + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA 
= layout::AffineRank2RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::AffineRank2RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + /// Default Operator + using Operator = Operator_; + + using Base = DefaultMmaCore; + + // + // Shared memory layouts + // + + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = typename Base::SmemIteratorA; + + /// Policy of iterator B + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = typename Base::SmemIteratorB; + + // + // Warp-level matrix multiply operator + // + + /// Policy used to define MmaPipelined + using MmaPolicy = typename Base::MmaPolicy; + +}; + //////////////////////////////////////////////////////////////////////////////// } // namespace threadblock diff --git a/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h b/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h new file mode 100644 index 00000000..fdf6ac26 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming + expectations about data layout of the global memory fragments, data types, + and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting TensorOp + instructions. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_mma_core.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass/gemm/threadblock/mma_with_reduction_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Template defining default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. 
+template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape_, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// + bool ReduceKForA_, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false// (is_complex::value || is_complex::value) +> +struct DefaultMmaWithReductionCore { + using Base = DefaultMmaCore; + using Shape = Shape_; + using IteratorThreadMapA = typename Base::IteratorThreadMapA; + using IteratorThreadMapB = typename Base::IteratorThreadMapB; + using SmemIteratorA = typename Base::SmemIteratorA; + using SmemIteratorB = typename Base::SmemIteratorB; + using SmemLayoutA = typename Base::SmemLayoutA; + using SmemLayoutB = typename Base::SmemLayoutB; + using WarpCount = typename Base::WarpCount; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, ReduceKForA_, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/gemm/threadblock/default_mma_with_reduction.h b/include/cutlass/gemm/threadblock/default_mma_with_reduction.h new file mode 100644 index 00000000..f0db7b07 --- /dev/null 
+++ b/include/cutlass/gemm/threadblock/default_mma_with_reduction.h @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// + bool ReduceKForA_, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Use zfill or predicate for SM80 out-of-bound cp.async + bool UseZfill = false + > +struct DefaultMmaWithReduction { + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaWithReductionCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + ReduceKForA_, Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaWithReductionMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages, UseZfill>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace 
threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_base.h b/include/cutlass/gemm/threadblock/mma_base.h index a56d81f0..c02e6aa9 100644 --- a/include/cutlass/gemm/threadblock/mma_base.h +++ b/include/cutlass/gemm/threadblock/mma_base.h @@ -35,6 +35,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_types.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h index d07b236d..851cdaee 100644 --- a/include/cutlass/gemm/threadblock/mma_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_multistage.h @@ -77,6 +77,8 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Use zfill or predicate for out-of-bound cp.async + bool UseZfill = false, /// Used for partial specialization typename Enable = bool> class MmaMultistage : @@ -228,8 +230,13 @@ public: for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_A.get(); - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); + if (UseZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } ++iterator_A; } @@ -258,8 +265,13 @@ public: for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); + if (UseZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } ++iterator_B; } @@ -514,10 +526,12 @@ public: } - // commit and drain all pending and predicated LDGSTS pnz from 
the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (UseZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } } }; diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_base.h b/include/cutlass/gemm/threadblock/mma_planar_complex_base.h index 22c9b3f8..202c67af 100644 --- a/include/cutlass/gemm/threadblock/mma_planar_complex_base.h +++ b/include/cutlass/gemm/threadblock/mma_planar_complex_base.h @@ -35,6 +35,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_types.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/include/cutlass/gemm/threadblock/mma_sparse_base.h b/include/cutlass/gemm/threadblock/mma_sparse_base.h index eb192f72..919da8e4 100644 --- a/include/cutlass/gemm/threadblock/mma_sparse_base.h +++ b/include/cutlass/gemm/threadblock/mma_sparse_base.h @@ -35,6 +35,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_types.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h new file mode 100644 index 00000000..20e64850 --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h @@ -0,0 +1,551 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + bool UseZfill = false, + /// Used for partial specialization + typename Enable = bool> +class MmaWithReductionMultistage : + public MmaBase { 
+public: + ///< Base class + using Base = MmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + using FragmentReduction = typename Operator::FragmentReduction; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static int const kReduceKForA = Operator::kReduceKForA; + + /// Internal structure exposed for introspection. 
+ struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaWithReductionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), 
+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (UseZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + 
iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (UseZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum, + FragmentReduction &gemm_k_reduction_accum) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + } + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v 
< IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + } + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum, + gemm_k_reduction_accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + } + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + } + + if (UseZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/include/cutlass/gemm/threadblock/threadblock_swizzle.h index 79314088..65c42568 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -33,6 +33,8 @@ #include "cutlass/layout/matrix.h" #include "cutlass/platform/platform.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -105,9 +107,8 @@ struct GemmIdentityThreadblockSwizzle { CUTLASS_HOST_DEVICE GemmIdentityThreadblockSwizzle() { } - int const kTile = N; - /// Returns the shape of the problem in units of logical tiles + /// *Gemm* problem size: gemm(M, N, K) CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape( GemmCoord problem_size, @@ -120,19 +121,77 @@ struct GemmIdentityThreadblockSwizzle { split_k_slices); } + /// Returns the shape of the problem in units of logical tiles + /// *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + GemmCoord tile_size, + int split_k_slices) const { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + return get_tiled_shape( + implicit_gemm_problem_size, tile_size, split_k_slices); + } + + /// Returns the shape of the problem in units of logical tiles + /// *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC) + CUTLASS_HOST_DEVICE + GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv3dProblemSize const &problem_size, + GemmCoord tile_size, + int split_k_slices) const { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, 
problem_size); + + return get_tiled_shape( + implicit_gemm_problem_size, tile_size, split_k_slices); + } + /// Computes CUDA grid dimensions given a size in units of logical tiles CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const { - if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) - return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); + int tile = 1 << get_log_tile(tiled_shape); + return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); + } - return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k()); + /// Calculates optimal swizzle width + CUTLASS_HOST_DEVICE + int get_log_tile(GemmCoord tiled_shape) const { + auto n = tiled_shape.n(); + // Thresholds picked so that it doesn't cause too many no-op CTAs + if (N >= 8 && n >= 6) + return 3; + else if (N >= 4 && n >= 3) + return 2; + else if (N >= 2 && n >= 2) + return 1; + else + return 0; + } + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int log_tile) const { + int block_idx_x = RematerializeBlockIdxX(); + int block_idx_y = RematerializeBlockIdxY(); + int block_idx_z = RematerializeBlockIdxZ(); + + return GemmCoord{(block_idx_x >> log_tile), // + (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), + block_idx_z}; } /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE GemmCoord get_tile_offset(GemmCoord tiled_shape) const { + int const kTile = N; int block_idx_x = RematerializeBlockIdxX(); int block_idx_y = RematerializeBlockIdxY(); @@ -174,6 +233,12 @@ struct GemmHorizontalThreadblockSwizzle { return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); } + /// Calculates optimal swizzle width + CUTLASS_HOST_DEVICE + int get_log_tile(GemmCoord tiled_shape) const { + return 0; + } + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE 
GemmCoord get_tile_offset(GemmCoord tiled_shape) const { @@ -209,6 +274,12 @@ struct GemmBatchedIdentityThreadblockSwizzle { return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); } + /// Calculates optimal swizzle width + CUTLASS_HOST_DEVICE + int get_log_tile(GemmCoord tiled_shape) const { + return 0; + } + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE GemmCoord get_tile_offset(GemmCoord tiled_shape) const { @@ -219,6 +290,18 @@ struct GemmBatchedIdentityThreadblockSwizzle { }; } + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int log_tile) const { + int block_idx_x = RematerializeBlockIdxX(); + int block_idx_y = RematerializeBlockIdxY(); + int block_idx_z = RematerializeBlockIdxZ(); + + return GemmCoord{(block_idx_x >> log_tile), // + (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), + block_idx_z}; + } + /// Gets the batch index CUTLASS_DEVICE int get_batch_idx() const { @@ -247,20 +330,45 @@ struct GemmSplitKIdentityThreadblockSwizzle { partitions); } + /// Calculates optimal swizzle width + CUTLASS_HOST_DEVICE + int get_log_tile(GemmCoord tiled_shape) const { + auto n = tiled_shape.n(); + // Thresholds picked so that it doesn't cause too many no-op CTAs + if (N >= 8 && n >= 6) + return 3; + else if (N >= 4 && n >= 3) + return 2; + else if (N >= 2 && n >= 2) + return 1; + else + return 0; + } + /// Computes CUDA grid dimensions given a size in units of logical tiles CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const { - if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) - return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); - - return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k()); + int tile = 1 << get_log_tile(tiled_shape); + return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); } + /// Obtains the 
threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int log_tile) const { + int block_idx_x = RematerializeBlockIdxX(); + int block_idx_y = RematerializeBlockIdxY(); + int block_idx_z = RematerializeBlockIdxZ(); + + return GemmCoord{(block_idx_x >> log_tile), // + (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), + block_idx_z}; + } /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE GemmCoord get_tile_offset(GemmCoord tiled_shape) const { + int const kTile = N; int block_idx_x = RematerializeBlockIdxX(); int block_idx_y = RematerializeBlockIdxY(); @@ -299,6 +407,21 @@ struct GemmSplitKHorizontalThreadblockSwizzle { return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); } + /// Calculates optimal swizzle width + CUTLASS_HOST_DEVICE + int get_log_tile(GemmCoord tiled_shape) const { + return 0; + } + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int log_tile) const { + return GemmCoord{ + RematerializeBlockIdxY(), + RematerializeBlockIdxX(), + RematerializeBlockIdxZ() + }; + } /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE @@ -335,6 +458,23 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle { return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k()); } + /// Calculates optimal swizzle width + CUTLASS_HOST_DEVICE + int get_log_tile(GemmCoord tiled_shape) const { + return 0; + } + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + BatchedGemmCoord get_tile_offset(int log_tile) const { + return BatchedGemmCoord{ + 0, // M is always 1 + RematerializeBlockIdxX(), + RematerializeBlockIdxZ(), + RematerializeBlockIdxY(), + }; + } + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE BatchedGemmCoord get_tile_offset() const { diff --git 
a/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h b/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h new file mode 100644 index 00000000..1215db11 --- /dev/null +++ b/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_with_reduction_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false> +struct DefaultMmaWithReductionTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaWithReductionTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h index 5877b95f..8c565cf5 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -326,6 +326,9 @@ public: /// Shape of underlying instruction using InstructionShape = typename ArchMmaOperator::Shape; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Complex transform on A operand static ComplexTransform const kTransformA = TransformA; @@ -618,6 +621,9 @@ public: /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Complex transform on A operand static ComplexTransform const kTransformA = TransformA; diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h index ba74fe96..94dd9520 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h @@ -121,6 +121,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename 
TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -162,7 +165,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -395,6 +398,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -619,6 +625,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -835,6 +844,9 @@ class MmaTensorOpAccumulatorTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -1159,6 +1171,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -1200,7 +1215,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -1441,6 +1456,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename 
TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -1666,6 +1684,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -1901,6 +1922,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -1946,7 +1970,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -2207,6 +2231,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -2249,7 +2276,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -2305,6 +2332,18 @@ public: return *this; } + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); + + if (k_group_idx_ & 1) + byte_offset_ ^= 0x40; + + return *this; + } + /// Advances the iterator along the advance dimension CUTLASS_DEVICE 
MmaTensorOpMultiplicandTileIterator & operator++() { diff --git a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h index 7cfad2ea..4a3111c9 100644 --- a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h @@ -159,6 +159,9 @@ public: /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Complex transform on A operand static ComplexTransform const kTransformA = TransformA; diff --git a/include/cutlass/gemm/warp/mma_simt.h b/include/cutlass/gemm/warp/mma_simt.h index a86e06e4..b32e110e 100644 --- a/include/cutlass/gemm/warp/mma_simt.h +++ b/include/cutlass/gemm/warp/mma_simt.h @@ -154,6 +154,9 @@ public: /// Underlying matrix multiply operator (concept: arch::Mma) using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Shape of the underlying instruction using InstructionShape = GemmShape<1,1,use_dp4a ? 
4 : 1>; diff --git a/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/include/cutlass/gemm/warp/mma_simt_tile_iterator.h index 660db388..4198762e 100644 --- a/include/cutlass/gemm/warp/mma_simt_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_simt_tile_iterator.h @@ -354,18 +354,27 @@ private: /// Internal reference cutlass::TensorRef ref_; + /// Extent of tensor + MatrixCoord extent_; + + /// Origin + MatrixCoord origin_; + + /// Used to conditionally enable extents checking + bool divisible_; + public: /// Default ctor constructs null iterator CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } + MmaSimtTileIterator() : divisible_(true) { } /// Constructor from TensorRef CUTLASS_HOST_DEVICE MmaSimtTileIterator( TensorRef ref, int lane_id - ) { + ) : extent_(Shape::kRow, Shape::kColumn), divisible_ (true) { // compute offset based on thread ID and lane layout typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); @@ -373,12 +382,35 @@ public: MatrixCoord lane_offset = lane_layout.inverse(lane_id) * MatrixCoord(Policy::LaneMmaShape::kM, 0); + origin_ = lane_offset; + ref.add_coord_offset(lane_offset); ref_.reset(ref.data(), ref.stride(0)); } + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaSimtTileIterator( + TensorRef ref, + TensorCoord extent, + int lane_id + ) : extent_(extent), divisible_ (false) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, 0); + + origin_ = lane_offset; + + ref.add_coord_offset(lane_offset); + + ref_.reset(ref.data(), ref.stride(0)); + + } /// Adds a pointer offset to internal pointer(s) to advance through memory CUTLASS_HOST_DEVICE @@ -391,9 +423,13 @@ public: CUTLASS_HOST_DEVICE MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - ref_.add_coord_offset({ + TensorCoord coord_offset( coord.row() * Shape::kRow, - 
coord.column() * Shape::kColumn}); + coord.column() * Shape::kColumn); + + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); return *this; } @@ -426,11 +462,21 @@ public: for (int m = 0; m < Iterations::kRow; ++m) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { - - frag[m * Policy::LaneMmaShape::kM + i + k * Iterations::kRow] = - *(ref_.data() + - ref_.offset({m * Policy::WarpShape::kRow * Policy::LaneMmaShape::kM + i, k}) + - pointer_offset); + + MatrixCoord offset(m * Policy::WarpShape::kRow * Policy::LaneMmaShape::kM + i, k); + + MatrixCoord access_coord = origin_ + offset; + + int frag_idx = m * Policy::LaneMmaShape::kM + i + k * Iterations::kRow; + + if (divisible_ || + (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { + + frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); + } + else { + frag[frag_idx] = Element(); + } } } } @@ -765,18 +811,27 @@ private: /// Internal reference cutlass::TensorRef ref_; + /// Extent of tensor + MatrixCoord extent_; + + /// Origin + MatrixCoord origin_; + + /// Used to conditionally enable extents checking + bool divisible_; + public: /// Default ctor constructs null iterator CUTLASS_HOST_DEVICE - MmaSimtTileIterator() { } + MmaSimtTileIterator(): divisible_(true) { } /// Constructor from TensorRef CUTLASS_HOST_DEVICE MmaSimtTileIterator( TensorRef ref, int lane_id - ) { + ): extent_(Shape::kRow, Shape::kColumn), divisible_(true) { // compute offset based on thread ID and lane layout typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); @@ -784,11 +839,34 @@ public: MatrixCoord lane_offset = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + origin_ = lane_offset; + ref.add_coord_offset(lane_offset); ref_.reset(ref.data(), ref.stride(0)); } - + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaSimtTileIterator( + TensorRef ref, + TensorCoord extent, + int lane_id + ): 
extent_(extent), divisible_(false) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(0, Policy::LaneMmaShape::kN); + + origin_ = lane_offset; + + ref.add_coord_offset(lane_offset); + + ref_.reset(ref.data(), ref.stride(0)); + } + /// Adds a pointer offset to internal pointer(s) to advance through memory CUTLASS_HOST_DEVICE MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { @@ -800,9 +878,13 @@ public: CUTLASS_HOST_DEVICE MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { - ref_.add_coord_offset({ + TensorCoord coord_offset( coord.row() * Shape::kRow, - coord.column() * Shape::kColumn}); + coord.column() * Shape::kColumn); + + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); return *this; } @@ -835,10 +917,21 @@ public: for (int n = 0; n < Iterations::kColumn; ++n) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Policy::LaneMmaShape::kN; ++i) { - frag[n * Policy::LaneMmaShape::kN + i + k * Iterations::kColumn] = - *(ref_.data() + - ref_.offset({k, n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + i}) + - pointer_offset); + + MatrixCoord offset(k, n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + i); + + MatrixCoord access_coord = origin_ + offset; + + int frag_idx = n * Policy::LaneMmaShape::kN + i + k * Iterations::kColumn; + + if (divisible_ || + (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { + + frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); + } + else { + frag[frag_idx] = Element(); + } } } } diff --git a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h index 86c50d37..f0416d76 100644 --- a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h @@ -119,6 +119,9 @@ public: /// Underlying matrix 
multiply operator (concept: arch::Mma) using ArchMmaOperator = typename Base::ArchMmaOperator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Architecture tag from underlying instruction using ArchTag = typename Base::ArchTag; diff --git a/include/cutlass/gemm/warp/mma_tensor_op.h b/include/cutlass/gemm/warp/mma_tensor_op.h index a6f83129..f098d520 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_tensor_op.h @@ -187,6 +187,9 @@ public: /// Underlying matrix multiply operator (concept: arch::Mma) using ArchMmaOperator = typename Policy::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Architecture tag from underlying instruction using ArchTag = typename ArchMmaOperator::ArchTag; @@ -400,3 +403,4 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h index e7a77f72..bbcedfcf 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -90,9 +90,6 @@ class MmaTensorOpFragmentIterator +class MmaTensorOpAccumulatorTileIterator< + Shape_, Element_, cutlass::layout::AffineRankN<2>, InstructionShape_, OpDelta_> { + public: + + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand::kC; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + using OpDelta = OpDelta_; + + /// Number of 
participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static bool const kDivisible = + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN); + + static_assert(platform::is_same::value, + "Layouts must be defined for logical MatrixCoord coordinate space."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, + (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN + >; + }; + +private: + + // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire + // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements + // of that row. The accumulators within one row are assumed to be consecutive. 
+ static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; + +private: + + /// Reference to output tensor + TensorRef ref_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator( + TensorRef const &ref, + int lane_id + ): + ref_(ref) { + + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + + MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); + + ref_.add_coord_offset(lane_offset); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator & operator++() { + // deliberate no-op + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpAccumulatorTileIterator & operator--() { + // deliberate no-op + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { + 
add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + Fragment &frag, ///< fragment to load from the tensor + Index pointer_offset) const { ///< loads a tile with a linear offset + + TensorRef offset_ref(ref_); + offset_ref.add_pointer_offset(pointer_offset); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile; + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; + + frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); + } + } + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + Fragment &frag, ///< fragment to load from the tensor + Index byte_offset) const { ///< loads a tile with a linear offset + + load_with_pointer_offset(byte_offset / sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. 
+ CUTLASS_DEVICE + void load( + Fragment &frag, ///< fragment to load from the tensor + TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles + + load(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + Fragment &frag, ///< fragment to load from the tensor + TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles + Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset + + load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); + } + + /// Stores a fragment to memory + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) const { + store_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory with additional pointer offset + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, ///< fragment to store from the tensor + Index pointer_offset) const { ///< store a tile with a linear offset + + TensorRef offset_ref(ref_); + offset_ref.add_pointer_offset(pointer_offset); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile; + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; + int idx = mma_accum_start + row * kElementsPerAccess + col; + + offset_ref.at({accum_m, accum_n}) = frag[idx]; + } + } + } + } + } + + /// Stores a fragment to memory with additional pointer offset + CUTLASS_DEVICE + void 
store_with_byte_offset( + Fragment const &frag, ///< fragment to store from the tensor + Index byte_offset) const { ///< store a tile with a linear offset + + store_with_pointer_offset(byte_offset / sizeof(Element)); + } + + /// Stores a fragment to memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void store( + Fragment &frag, ///< fragment to store to the tensor + TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles + + store(frag, tile_offset, 0); + } + + /// Stores a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void store( + /// fragment to store to the tensor + Fragment const &frag, + /// stores a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// stores a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + /// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store /// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major /// accumulator layout. 
@@ -3289,6 +3602,9 @@ class MmaTensorOpAccumulatorTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index 4be831f3..716d476d 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -123,6 +123,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -171,7 +174,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_[kPointerCount]; @@ -436,6 +439,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -480,7 +486,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -1526,6 +1532,9 @@ class MmaVoltaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ 
-1566,7 +1575,7 @@ class MmaVoltaTensorOpMultiplicandTileIterator< private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h index 4d45ecf5..b5709544 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h @@ -121,6 +121,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -166,7 +169,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -877,6 +880,9 @@ class MmaTensorOpMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Long Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -919,7 +925,7 @@ public: private: /// Layout object storing stride values - Index stride_; + StrideIndex stride_; /// Shared memory base pointers - not advanced AccessType const *pointer_; @@ -982,6 +988,16 @@ public: return *this; } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); // TODO fix this if it becomes an issue during warp it reset + + return *this; + } + /// Advances the iterator along the advance dimension CUTLASS_DEVICE 
MmaTensorOpMultiplicandTileIterator & operator++() { @@ -1237,6 +1253,15 @@ public: return *this; } + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); + + return *this; + } + /// Advances the iterator along the advance dimension CUTLASS_HOST_DEVICE MmaTensorOpMultiplicandTileIterator & operator++() { @@ -1461,6 +1486,15 @@ public: return *this; } + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); + + return *this; + } + /// Advances the iterator along the advance dimension CUTLASS_HOST_DEVICE MmaTensorOpMultiplicandTileIterator & operator++() { diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h index 6fd783c6..e6505a93 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h @@ -130,6 +130,9 @@ class MmaTensorOpWmmaMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Stride Index type + using StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -180,7 +183,7 @@ private: Index byte_offset_; /// Stride in units of number of elements - Index stride_; + StrideIndex stride_; /// Layout of shared memory Layout layout_; @@ -375,6 +378,9 @@ class MmaTensorOpWmmaMultiplicandTileIterator< /// Long Index type using LongIndex = typename TensorRef::LongIndex; + /// Stride Index type + using 
StrideIndex = typename TensorRef::Layout::Stride::Index; + /// Coordinate for an element in the tensor using TensorCoord = typename TensorRef::TensorCoord; @@ -425,7 +431,7 @@ private: Index byte_offset_; /// Stride in units of number of elements - Index stride_; + StrideIndex stride_; /// Layout of shared memory Layout layout_; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/include/cutlass/gemm/warp/mma_tensor_op_wmma.h index c000dd62..8c26b9ec 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_wmma.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_wmma.h @@ -109,6 +109,12 @@ public: /// Underlying instruction shape using InstructionShape = typename Policy::Operator::Shape; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + /// Underlying architecture tag using ArchTag = typename Policy::Operator::ArchTag; diff --git a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h new file mode 100644 index 00000000..ace57427 --- /dev/null +++ b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting Tensor Cores. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// + bool ReduceKForA_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class MmaWithReductionTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + static bool const kReduceKForA = ReduceKForA_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename 
IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + + using FragmentReduction = Array; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaWithReductionTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C, + FragmentReduction &gemm_k_reduction + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; 
++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + + if (!kReduceKForA && m == 0) { +// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); +// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); +// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); +// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 3]); + + uint32_t const *tmp = reinterpret_cast(&B); + asm volatile( + "{\n\t" + " .reg .f16 low, high;\n\t" + " .reg .f32 tmp;\n\t" + " mov.b32 {low, high}, %1;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " mov.b32 {low, high}, %2;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %0, tmp, %0;\n\t" + "}\n\t" + : 
"+f"(gemm_k_reduction[n_serpentine]) + : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); + } + } + + if (kReduceKForA && (n == 0)) { +// gemm_k_reduction[m * 2] += float(A[m * 8]); +// gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); +// gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); +// gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); +// +// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); +// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); +// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); +// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 7]); + + uint32_t const *tmp = reinterpret_cast(&A); + asm volatile( + "{\n\t" + " .reg .f16 low, high;\n\t" + " .reg .f32 tmp;\n\t" + " mov.b32 {low, high}, %2;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " mov.b32 {low, high}, %3;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " mov.b32 {low, high}, %4;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " mov.b32 {low, high}, %5;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %1, tmp, %1;\n\t" + "}\n\t" + : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) + : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + + // + // Define conversions from source type to instruction type + // + FloatRoundStyle const kRoundA = + PreferredRoundingMode::kRound; + FloatRoundStyle const kRoundB = + PreferredRoundingMode::kRound; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 
800) + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_B = + reinterpret_cast const *>(&B); + Array * + ptr_dst_B = reinterpret_cast *>(&dst_B); + + dst_A = convert_A(A); + + ptr_dst_B[0] = convert_B(ptr_B[0]); + ptr_dst_B[1] = convert_B(ptr_B[1]); + + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array * + ptr_dst_A = reinterpret_cast *>(&dst_A); + + dst_B = convert_B(B); + + ptr_dst_A[0] = convert_A(ptr_A[0]); + ptr_dst_A[1] = convert_A(ptr_A[1]); + #else + assert(0); + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/half.h b/include/cutlass/half.h index caa66657..c6e54b02 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -59,9 +59,13 @@ enum #define CUTLASS_ENABLE_F16C 0 #else +// +// Standard Library headers belong here to avoid conflicts with NVRTC. 
+// #include #include #include +#include #endif /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -211,7 +215,14 @@ struct alignas(2) half_t { #endif // software implementation rounds toward nearest even - unsigned const& s = reinterpret_cast(flt); + unsigned s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + uint16_t sign = uint16_t((s >> 16) & 0x8000); int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); int mantissa = s & 0x7fffff; @@ -340,7 +351,13 @@ struct alignas(2) half_t { f = (0xff << 23) | (sign << 31); // inf } } + #if defined(__CUDA_ARCH__) return reinterpret_cast(f); + #else + float flt; + std::memcpy(&flt, &f, sizeof(flt)); + return flt; + #endif #endif } @@ -354,8 +371,13 @@ struct alignas(2) half_t { /// Reinterpret cast from CUDA's half type CUTLASS_HOST_DEVICE - explicit half_t(half const & x): storage(reinterpret_cast(x)) { - + explicit half_t(half const & x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + __half_raw raw(x); + std::memcpy(&storage, &raw.x, sizeof(storage)); + #endif } /// Floating point conversion @@ -385,7 +407,12 @@ struct alignas(2) half_t { /// Assignment CUTLASS_HOST_DEVICE half_t & operator=(half const &x) { + #if defined(__CUDA_ARCH__) storage = reinterpret_cast(x); + #else + __half_raw raw(x); + std::memcpy(&storage, &raw.x, sizeof(storage)); + #endif return *this; } @@ -416,7 +443,13 @@ struct alignas(2) half_t { /// Bitcasts to CUDA's half type CUTLASS_HOST_DEVICE half to_half() const { + #if defined(__CUDA_ARCH__) return reinterpret_cast(storage); + #else + __half_raw raw; + std::memcpy(&raw.x, &storage, sizeof(raw.x)); + return half(raw); + #endif } /// Accesses raw internal state @@ -529,11 +562,11 @@ cutlass::half_t sqrt(cutlass::half_t const& h) { CUTLASS_HOST_DEVICE half_t copysign(half_t const& a, half_t const& b) { - uint16_t a_mag = (reinterpret_cast(a) & 0x7fff); - 
uint16_t b_sign = (reinterpret_cast(b) & 0x8000); + uint16_t a_mag = (a.raw() & 0x7fff); + uint16_t b_sign = (b.raw() & 0x8000); uint16_t result = (a_mag | b_sign); - return reinterpret_cast(result); + return half_t::bitcast(result); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -546,9 +579,9 @@ half_t copysign(half_t const& a, half_t const& b) { // /////////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass { -namespace platform { +namespace std { +#if !defined(__CUDACC_RTC__) /// Numeric limits template <> struct numeric_limits { @@ -594,9 +627,8 @@ struct numeric_limits { /// Returns smallest finite value static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } }; - -} // namespace platform -} // namespace cutlass +#endif +} // namespace std /////////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/layout/layout.h b/include/cutlass/layout/layout.h index 4d78c4c4..c1170b08 100644 --- a/include/cutlass/layout/layout.h +++ b/include/cutlass/layout/layout.h @@ -49,6 +49,7 @@ namespace layout { /////////////////////////////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layout diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index 668245fc..c4467200 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -34,7 +34,9 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" #include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" namespace cutlass { namespace layout { @@ -64,7 +66,7 @@ public: using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; private: // @@ -81,7 +83,7 @@ 
public: /// Constructor CUTLASS_HOST_DEVICE - RowMajor(Index ldm = 0): stride_(ldm) { } + RowMajor(LongIndex ldm = 0): stride_(ldm) { } /// Ctor CUTLASS_HOST_DEVICE @@ -120,13 +122,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -156,7 +158,7 @@ public: using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; private: // @@ -173,7 +175,7 @@ public: /// Ctor CUTLASS_HOST_DEVICE - ColumnMajor(Index ldm = 0): stride_(ldm) { } + ColumnMajor(LongIndex ldm = 0): stride_(ldm) { } /// Ctor CUTLASS_HOST_DEVICE @@ -213,13 +215,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -251,7 +253,7 @@ struct RowMajorInterleaved { using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; /// Size of interleaved columns static int const kInterleave = Interleave; @@ -271,7 +273,7 @@ public: /// Ctor CUTLASS_HOST_DEVICE - RowMajorInterleaved(Index ldm = 0): stride_(ldm) { } + RowMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } /// Ctor CUTLASS_HOST_DEVICE @@ -319,13 +321,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -357,7 +359,7 @@ struct ColumnMajorInterleaved { 
using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; /// Size of interleaved columns static int const kInterleave = Interleave; @@ -377,7 +379,7 @@ public: /// Ctor CUTLASS_HOST_DEVICE - ColumnMajorInterleaved(Index ldm = 0): stride_(ldm) { } + ColumnMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } /// Ctor CUTLASS_HOST_DEVICE @@ -426,13 +428,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -469,7 +471,7 @@ struct ContiguousMatrix { using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; private: // @@ -548,13 +550,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -574,6 +576,412 @@ public: } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for scenario in which both rows and columns are separated by a stride. 
+template +struct AffineRankN { + + /// Logical rank of tensor + static int const kRank = Rank; + + /// Rank of stride vector + static int const kStrideRank = kRank; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = Coord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member + Stride stride_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRankN( + Stride const &stride = Stride() + ): + stride_(stride) { } + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRankN( + Coord const &stride_m, + Coord const &stride_n + ) { + + // Concatenate the strides + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kRank/2; ++m) { + stride_[m] = stride_m[m]; + } + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kRank/2; ++n) { + stride_[n + kRank/2] = stride_n[n]; + } + } + + /// Ctor for N = 2 + CUTLASS_HOST_DEVICE + AffineRankN( + LongIndex const &stride_m, + LongIndex const &stride_n + ) { + stride_[0] = stride_m; + stride_[1] = stride_n; + } + + /// Ctor for N = 2 + CUTLASS_HOST_DEVICE + AffineRankN( + LongIndex const &stride + ) { + stride_[0] = stride; + stride_[1] = 1; + } + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static AffineRankN packed(TensorCoord const &extent) { + + AffineRankN layout; + layout.stride_[kRank - 1] = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = kRank - 1; i > 0; --i) { + layout.stride_[i - 1] = layout.stride_[i] * extent[i]; + } + + return layout; + } + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return dot(coord, stride_); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + TensorCoord inverse(LongIndex offset) const { + // TODO + return TensorCoord(); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const { + return stride_[idx]; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + typename Stride::Index & stride(int idx) { + return stride_[idx]; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + int idx = stride_.max_dim_index(); + return extent[idx] * stride_[idx]; + } +}; + +/// Mapping function for scenario in which both rows and columns are separated by a stride. +/// Row stride is smaller than column stride in AffineRank2ColumnMajor. 
+struct AffineRank2ColumnMajor { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 2; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member + Stride stride_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRank2ColumnMajor( + Stride const &stride = Stride() + ): + stride_(stride) { } + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRank2ColumnMajor( + LongIndex row_stride, ///< stride between elements in consecutive rows + LongIndex column_stride ///< stride between elements in consecutive columns + ) + { stride_[0] = row_stride; stride_[1] = column_stride;} + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRank2ColumnMajor( + LongIndex stride + ) + { stride_[0] = 1; stride_[1] = stride;} + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static AffineRank2ColumnMajor packed(MatrixCoord const &extent) { + return AffineRank2ColumnMajor(extent.column(), 1); + } + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord const &coord) const { + return dot(coord, stride_); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + MatrixCoord inverse(LongIndex offset) const { + // TODO + return MatrixCoord(0, 0); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const { + return stride_[idx]; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + typename Stride::Index & stride(int idx) { + return stride_[idx]; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(MatrixCoord const &extent) const { + return extent.column() * stride_[1]; + } +}; + +/// Mapping function for scenario in which both rows and columns are separated by a stride. +/// Column stride is smaller than row stride in AffineRank2RowMajor. 
+struct AffineRank2RowMajor { + + /// Logical rank of tensor + static int const kRank = 2; + + /// Rank of stride vector + static int const kStrideRank = 2; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = MatrixCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member + Stride stride_; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRank2RowMajor( + Stride const &stride = Stride() + ): + stride_(stride) { } + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRank2RowMajor( + LongIndex row_stride, ///< stride between elements in consecutive rows + LongIndex column_stride ///< stride between elements in consecutive columns + ) { stride_[0] = row_stride; stride_[1] = column_stride;} + + /// Ctor + CUTLASS_HOST_DEVICE + AffineRank2RowMajor( + LongIndex stride + ) { stride_[0] = stride; stride_[1] = 1;} + + /// Helper returns a layout to a tightly packed tensor + CUTLASS_HOST_DEVICE + static AffineRank2RowMajor packed(MatrixCoord const &extent) { + return AffineRank2RowMajor(extent.column(), 1); + } + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord const &coord) const { + return dot(coord, stride_); + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + CUTLASS_HOST_DEVICE + MatrixCoord inverse(LongIndex offset) const { + // TODO + return MatrixCoord(0, 0); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const { + return stride_[idx]; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + typename Stride::Index & stride(int idx) { + return stride_[idx]; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(MatrixCoord const &extent) const { + return extent.row() * stride_[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Utility functions to convert stride_factor to the strides used by the Affine2 layout. +// +// stride_factor is the logical distance between two coordinates. +// +// All Coordinates used here are matrix coordinates. stride[0] and extent[0] are for the +// rows. stride[1] and extent[1] are for the columns. 
+template + struct Affine2Layout_Factory { + CUTLASS_HOST_DEVICE + static Affine2Layout layout_factory(cutlass::Coord<2> const &extent, typename Affine2Layout::Stride stride_factor) { + return Affine2Layout::packed(extent); + } +}; + +template <> +struct Affine2Layout_Factory { +CUTLASS_HOST_DEVICE +static cutlass::layout::AffineRank2ColumnMajor layout_factory( + cutlass::Coord<2> const &extent, + typename cutlass::layout::AffineRank2ColumnMajor::Stride stride_factor) { + return cutlass::layout::AffineRank2ColumnMajor({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); + } +}; + +template <> +struct Affine2Layout_Factory { +CUTLASS_HOST_DEVICE +static cutlass::layout::AffineRank2RowMajor layout_factory( + cutlass::Coord<2> const &extent, + typename cutlass::layout::AffineRank2RowMajor::Stride stride_factor) { + return cutlass::layout::AffineRank2RowMajor({ stride_factor[0] * stride_factor[1] * extent[1], stride_factor[1] }); + } +}; + +// The base layout cutlass::layout::AffineRankN<2> is similar to AffineRank2ColumnMajor +template <> +struct Affine2Layout_Factory> { +CUTLASS_HOST_DEVICE +static cutlass::layout::AffineRankN<2> layout_factory( + cutlass::Coord<2> const &extent, + typename cutlass::layout::AffineRankN<2>::Stride stride_factor) { + return cutlass::layout::AffineRankN<2>({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Mapping function for block-linear matrices. Matrix is structured /// as column-major arrangement of 2D tiles (that are column-major). 
template @@ -594,7 +1002,7 @@ struct ColumnMajorBlockLinear { using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; /// Size of a block in rows static int const kBlockRows = BlockRows; @@ -658,13 +1066,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -695,7 +1103,7 @@ struct RowMajorBlockLinear { using TensorCoord = MatrixCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; /// Size of a block in rows static int const kBlockRows = BlockRows; @@ -758,13 +1166,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } @@ -887,13 +1295,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int idx) const { + typename Stride::Index stride(int idx) const { return stride_[idx]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & stride(int idx) { + typename Stride::Index & stride(int idx) { return stride_[idx]; } diff --git a/include/cutlass/layout/pitch_linear.h b/include/cutlass/layout/pitch_linear.h index a44825c1..92bc08a5 100644 --- a/include/cutlass/layout/pitch_linear.h +++ b/include/cutlass/layout/pitch_linear.h @@ -29,138 +29,14 @@ #include "cutlass/cutlass.h" #include "cutlass/coord.h" +#include "cutlass/pitch_linear_coord.h" namespace cutlass { namespace layout { -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Template defining a shape used by 
pitch-linear operators -template < - int Contiguous, - int Strided -> -struct PitchLinearShape { - static int const kContiguous = Contiguous; - static int const kStrided = Strided; - static int const kCount = Contiguous * Strided; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Coordinate in pitch-linear space -struct PitchLinearCoord : public Coord<2, int> { -public: - - /// Integer-valued index - using Index = int; - - /// Base type is a Coord of rank=2 - using Base = Coord<2, Index>; - -private: - - /// Rows dimension - static int const kContiguous = 0; - - /// Columns dimension - static int const kStrided = 1; - -public: - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - PitchLinearCoord() { } - - /// Constructs from Coord<2> - CUTLASS_HOST_DEVICE - PitchLinearCoord(Coord<2, Index> const &coord): Base(coord) { } - - /// Helper to construct from a row and column - CUTLASS_HOST_DEVICE - PitchLinearCoord(Index contiguous_, Index strided_): Base(make_Coord(contiguous_, strided_)) { } - - /// Returns the contiguous dimension - CUTLASS_HOST_DEVICE - Index const & contiguous() const { return this->at(kContiguous); } - - /// Returns the contiguous dimension - CUTLASS_HOST_DEVICE - Index & contiguous() { return this->at(kContiguous); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index const & strided() const { return this->at(kStrided); } - - /// Returns the column of the coordinate - CUTLASS_HOST_DEVICE - Index & strided() { return this->at(kStrided); } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - PitchLinearCoord operator+(Base const& b) const { - return PitchLinearCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - PitchLinearCoord operator-(Base const& b) const { - return PitchLinearCoord(Base::operator-(b)); - } - - CUTLASS_HOST_DEVICE - PitchLinearCoord operator-() const { - return 
PitchLinearCoord(-at(0), -at(1)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - PitchLinearCoord operator*(Base const& b) const { - return PitchLinearCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - PitchLinearCoord operator/(Base const& b) const { - return PitchLinearCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - PitchLinearCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; +template + using PitchLinearShape = cutlass::PitchLinearShape < Contiguous, Strided >; + using PitchLinearCoord = PitchLinearCoord; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -183,7 +59,7 @@ public: using TensorCoord = PitchLinearCoord; /// Stride vector - using Stride = Coord; + using Stride = Coord; private: // @@ -200,7 +76,7 @@ public: /// Constructor CUTLASS_HOST_DEVICE - PitchLinear(Index ldm = 0): stride_(ldm) { } + PitchLinear(LongIndex ldm = 0): stride_(ldm) { } /// Constructor CUTLASS_HOST_DEVICE @@ -223,8 +99,8 @@ public: CUTLASS_HOST_DEVICE TensorCoord inverse(LongIndex index) const { return make_Coord( - Index(index % stride_[0]), - Index(index / stride_[0]) + TensorCoord::Index(index % stride_[0]), + TensorCoord::Index(index / stride_[0]) ); } @@ -242,13 +118,13 @@ public: /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index stride(int rank) const { + LongIndex stride(int rank) const { return stride_[rank]; } /// Returns the stride of the layout CUTLASS_HOST_DEVICE - Index & 
stride(int rank) { + LongIndex & stride(int rank) { return stride_[rank]; } diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 1196b726..e9b9d544 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -101,6 +101,16 @@ public: ): stride_(make_Coord(stride_w, stride_h, stride_n)) { } + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNHWC(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + /// Helper returns a layout to a tightly packed NHWC tensor. CUTLASS_HOST_DEVICE static TensorNHWC packed(TensorCoord const &extent) { @@ -323,6 +333,16 @@ public: ): stride_(make_Coord(stride_w, stride_h, stride_n)) { } + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNCxHWx(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + /// Helper returns a layout to a tightly packed tensor CUTLASS_HOST_DEVICE static TensorNCxHWx packed(TensorCoord const &extent) { @@ -422,6 +442,17 @@ public: ): stride_(make_Coord(stride_w, stride_h, stride_n)) { } + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCxRSKx(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed tensor CUTLASS_HOST_DEVICE static TensorCxRSKx packed(TensorCoord const &extent) { @@ -524,6 +555,17 @@ public: typename Stride::Index dhwc): stride_(make_Coord(c, wc, hwc, dhwc)) { } + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNDHWC(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + 
static_cast(stride[1]), + static_cast(stride[2]), + static_cast(stride[3])) + ) { } + /// Helper returns a layout to a tightly packed NHWC tensor. CUTLASS_HOST_DEVICE static TensorNDHWC packed(TensorCoord const &extent) { diff --git a/include/cutlass/matrix.h b/include/cutlass/matrix.h index 971f125e..abbf9dfb 100644 --- a/include/cutlass/matrix.h +++ b/include/cutlass/matrix.h @@ -29,8 +29,10 @@ #pragma once +#if !defined(__CUDACC_RTC__) #include #include +#endif #include "cutlass/cutlass.h" #include "cutlass/array.h" @@ -222,6 +224,7 @@ struct Matrix { return slice_1x2(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x2(v, i, 0); } @@ -803,6 +806,7 @@ struct Matrix { return slice_1x3(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x3(v, i, 0); } @@ -1456,6 +1460,7 @@ struct Matrix { return slice_1x4(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x4(v, i, 0); } @@ -2084,6 +2089,7 @@ struct Matrix { return slice_2x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_2x1(v, 0, j); } @@ -2726,6 +2732,7 @@ struct Matrix { return slice_1x2(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x2(v, i, 0); } @@ -2756,6 +2763,7 @@ struct Matrix { return slice_2x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_2x1(v, 0, j); } @@ -3532,6 +3540,7 @@ struct Matrix { return slice_1x3(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x3(v, i, 0); } @@ -3562,6 +3571,7 @@ struct Matrix { return slice_2x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_2x1(v, 0, j); } @@ -4434,6 +4444,7 @@ struct Matrix { return slice_1x4(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return 
set_slice_1x4(v, i, 0); } @@ -4464,6 +4475,7 @@ struct Matrix { return slice_2x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_2x1(v, 0, j); } @@ -5335,6 +5347,7 @@ struct Matrix { return slice_3x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_3x1(v, 0, j); } @@ -6033,6 +6046,7 @@ struct Matrix { return slice_1x2(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x2(v, i, 0); } @@ -6111,6 +6125,7 @@ struct Matrix { return slice_3x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_3x1(v, 0, j); } @@ -6964,6 +6979,7 @@ struct Matrix { return slice_1x3(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x3(v, i, 0); } @@ -7071,6 +7087,7 @@ struct Matrix { return slice_3x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_3x1(v, 0, j); } @@ -8219,6 +8236,7 @@ struct Matrix { return slice_1x4(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x4(v, i, 0); } @@ -8359,6 +8377,7 @@ struct Matrix { return slice_3x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_3x1(v, 0, j); } @@ -9433,6 +9452,7 @@ struct Matrix { return slice_4x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_4x1(v, 0, j); } @@ -10180,6 +10200,7 @@ struct Matrix { return slice_1x2(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x2(v, i, 0); } @@ -10312,6 +10333,7 @@ struct Matrix { return slice_4x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_4x1(v, 0, j); } @@ -11258,6 +11280,7 @@ struct Matrix { return slice_1x3(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return 
set_slice_1x3(v, i, 0); } @@ -11454,6 +11477,7 @@ struct Matrix { return slice_4x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_4x1(v, 0, j); } @@ -12644,6 +12668,7 @@ struct Matrix { return slice_1x4(i, 0); } + CUTLASS_HOST_DEVICE Matrix &set_row(Matrix const &v, int i = 0) { return set_slice_1x4(v, i, 0); } @@ -12914,6 +12939,7 @@ struct Matrix { return slice_4x1(0, j); } + CUTLASS_HOST_DEVICE Matrix &set_column(Matrix const &v, int j =0) { return set_slice_4x1(v, 0, j); } @@ -14090,20 +14116,6 @@ Matrix operator*(Element s, Matrix -std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { - - for (int i = 0; i < Rows; ++i) { - for (int j = 0; j < Columns; ++j) { - out << (j ? ", " : "") << rhs.at(i, j); - } - out << "\n"; - } - - return out; -} - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/matrix_coord.h b/include/cutlass/matrix_coord.h index dcf25cc6..c435d6b0 100644 --- a/include/cutlass/matrix_coord.h +++ b/include/cutlass/matrix_coord.h @@ -46,6 +46,9 @@ public: /// Base type is a Coord of rank=2 using Base = Coord<2, Index>; + /// LongIndex type + using LongIndex = typename Base::LongIndex; + private: /// Rows dimension @@ -72,6 +75,10 @@ public: CUTLASS_HOST_DEVICE MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { } + /// Helper to construct from a row and column, which are LongIndex based + CUTLASS_HOST_DEVICE + MatrixCoord(LongIndex row, LongIndex column): Base(make_Coord(Index(row), Index(column))) { } + /// Returns the row of the coordinate CUTLASS_HOST_DEVICE Index const & row() const { return this->at(kRow); } diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 5f00688c..1a00db13 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -28,6 +28,10 @@ */ #pragma once +#if 
!defined(__CUDACC_RTC_) +#include +#endif + #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -77,30 +81,175 @@ struct NumericConverter { ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Partial specializations for float => int8_t +// Partial specializations for float => int32_t // ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct NumericConverter { - using result_type = int8_t; +#if defined(__CUDA_ARCH__) +template <> +struct NumericConverter { + + using result_type = int32_t; using source_type = float; - static FloatRoundStyle const round_style = Round; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static result_type convert(source_type const & s) { - result_type result = static_cast(s); - - return result; + return __float2int_rn(s); } - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE result_type operator()(source_type const &s) { return convert(s); } }; +template <> +struct NumericConverter { + + using result_type = int32_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + CUTLASS_DEVICE + static result_type convert(source_type const & s) { + + return __float2int_rz(s); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +#elif !defined(__CUDACC_RTC__) + +template <> +struct NumericConverter { + + using result_type = int32_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + static result_type convert(source_type const & s) { + std::fesetround(FE_TONEAREST); + return (result_type)std::nearbyint(s); + } + + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + + using result_type = int32_t; + using source_type = float; + 
static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + static result_type convert(source_type const & s) { + std::fesetround(FE_TOWARDZERO); + return (result_type)std::nearbyint(s); + } + + result_type operator()(source_type const &s) { + return convert(s); + } +}; +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for float => int8_t +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) +template <> +struct NumericConverter { + + using result_type = int8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + CUTLASS_DEVICE + static result_type convert(source_type const & s) { + + int32_t intermediate = __float2int_rn(s); + + return static_cast(intermediate); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + + using result_type = int8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + CUTLASS_DEVICE + static result_type convert(source_type const & s) { + + int32_t intermediate = __float2int_rz(s); + + return static_cast(intermediate); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +#elif !defined(__CUDACC_RTC__) + +template <> +struct NumericConverter { + + using result_type = int8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + static result_type convert(source_type const & s) { + std::fesetround(FE_TONEAREST); + int32_t intermediate = (result_type)std::nearbyint(s); + return static_cast(intermediate); + } + + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +template <> +struct NumericConverter { + + 
using result_type = int8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + static result_type convert(source_type const & s) { + std::fesetround(FE_TOWARDZERO); + int32_t intermediate = (result_type)std::nearbyint(s); + return static_cast(intermediate); + } + + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for float <= half_t diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index 363997b6..58e1efb1 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -43,7 +43,7 @@ namespace cutlass { /// Defines the size of an element in bits template struct sizeof_bits { - static int const value = sizeof(T) * 8; + static int const value = int(sizeof(T) * 8); }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pitch_linear_coord.h b/include/cutlass/pitch_linear_coord.h new file mode 100644 index 00000000..af2970b5 --- /dev/null +++ b/include/cutlass/pitch_linear_coord.h @@ -0,0 +1,175 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. 
+*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defining a shape used by pitch-linear operators +template < + int Contiguous, + int Strided +> +struct PitchLinearShape { + static int const kContiguous = Contiguous; + static int const kStrided = Strided; + static int const kCount = Contiguous * Strided; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Coordinate in pitch-linear space +struct PitchLinearCoord : public Coord<2, int> { +public: + + /// Integer-valued index + using Index = int; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// Long integer type + using LongIndex = typename Base::LongIndex; + +private: + + /// Rows dimension + static int const kContiguous = 0; + + /// Columns dimension + static int const kStrided = 1; + +public: + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + PitchLinearCoord() { } + + /// Constructs from Coord<2> + CUTLASS_HOST_DEVICE + PitchLinearCoord(Coord<2, Index> const &coord): Base(coord) { } + + /// Helper to construct from a row and column + CUTLASS_HOST_DEVICE + PitchLinearCoord(Index contiguous_, Index strided_): Base(make_Coord(contiguous_, strided_)) { } + + /// Helper to construct from a row and column based on LongIndex + CUTLASS_HOST_DEVICE + PitchLinearCoord(LongIndex contiguous_, LongIndex strided_) + : Base(make_Coord(Index(contiguous_), Index(strided_))) { } + + /// Returns the contiguous dimension + CUTLASS_HOST_DEVICE + Index const & contiguous() const { return this->at(kContiguous); } + + /// Returns the contiguous dimension + CUTLASS_HOST_DEVICE + Index & contiguous() { return this->at(kContiguous); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & strided() const { return this->at(kStrided); } + + 
/// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & strided() { return this->at(kStrided); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + PitchLinearCoord operator+(Base const& b) const { + return PitchLinearCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + PitchLinearCoord operator-(Base const& b) const { + return PitchLinearCoord(Base::operator-(b)); + } + + CUTLASS_HOST_DEVICE + PitchLinearCoord operator-() const { + return PitchLinearCoord(-at(0), -at(1)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + PitchLinearCoord operator*(Base const& b) const { + return PitchLinearCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + PitchLinearCoord operator/(Base const& b) const { + return PitchLinearCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + PitchLinearCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + PitchLinearCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + PitchLinearCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + PitchLinearCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index 6ef748fb..7ec10933 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -129,7 +129,7 @@ struct PredicateVector { static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; /// Number of storage elements needed - static int const kWordCount = (kBytes + 
sizeof(Storage) - 1) / sizeof(Storage); + static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); private: // diff --git a/include/cutlass/quaternion.h b/include/cutlass/quaternion.h index 67e0634a..df3102b3 100644 --- a/include/cutlass/quaternion.h +++ b/include/cutlass/quaternion.h @@ -30,6 +30,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/real.h" #include "cutlass/coord.h" #include "cutlass/matrix.h" #include "cutlass/fast_math.h" @@ -82,18 +83,27 @@ public: // Methods // - /// Constructs a quaternion + /// Constructs a quaternion q = 0 + CUTLASS_HOST_DEVICE + Quaternion() { + Base::at(kX) = Element(); + Base::at(kY) = Element(); + Base::at(kZ) = Element(); + Base::at(kW) = Element(); + } + + /// Constructs a quaternion q = w + 0*i + 0*j + 0*k CUTLASS_HOST_DEVICE Quaternion( - Element w_ = Element(1) + Element w_ ) { - Base::at(kX) = Element(0); - Base::at(kY) = Element(0); - Base::at(kZ) = Element(0); + Base::at(kX) = Element(); + Base::at(kY) = Element(); + Base::at(kZ) = Element(); Base::at(kW) = w_; } - /// Constructs a quaternion + /// Constructs a quaternion q = w + x*i + y*j + z*k CUTLASS_HOST_DEVICE Quaternion( Element x_, @@ -355,7 +365,21 @@ Quaternion make_Quaternion(Element x, Element y, Element z, Element w) ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Returns the magnitude of the complex number +/// Returns the real part of the quaternion number +template +CUTLASS_HOST_DEVICE +Element const &real(Quaternion const &q) { + return q.w(); +} + +/// Returns the real part of the quaternion number +template +CUTLASS_HOST_DEVICE +Element &real(Quaternion &q) { + return q.w(); +} + +/// Returns the magnitude of the quaternion number template CUTLASS_HOST_DEVICE Element abs(Quaternion const &q) { @@ -599,13 +623,38 @@ Matrix3x1 spinor_rotation_inv( 
///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Output operators -// +/// Partial specialization for Quaternion-valued type. +template +struct RealType< Quaternion > { + using Type = T; -template -std::ostream &operator<<(std::ostream &out, Quaternion const &q) { - return out << q.w() << "+i" << q.x() << "+j" << q.y() << "+k" << q.z(); + /// Number of elements + static int const kExtent = Quaternion::kExtent; + +CUTLASS_HOST_DEVICE + static Quaternion from_real(double x) { + return Quaternion(static_cast(x)); + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +cutlass::Quaternion from_real >(double r) { + return cutlass::Quaternion(half_t(r)); +} + +template <> +CUTLASS_HOST_DEVICE +cutlass::Quaternion from_real >(double r) { + return cutlass::Quaternion(float(r)); +} + +template <> +CUTLASS_HOST_DEVICE +cutlass::Quaternion from_real >(double r) { + return cutlass::Quaternion(r); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/real.h b/include/cutlass/real.h index faa7d92d..83ffcd56 100644 --- a/include/cutlass/real.h +++ b/include/cutlass/real.h @@ -36,6 +36,9 @@ template struct RealType { using Type = T; + /// Number of elements + static int const kExtent = 1; + CUTLASS_HOST_DEVICE static T from_real(double x) { return static_cast(x); diff --git a/include/cutlass/reduction/device/reduce_split_k.h b/include/cutlass/reduction/device/reduce_split_k.h index f8558643..c5b5f6fe 100644 --- a/include/cutlass/reduction/device/reduce_split_k.h +++ b/include/cutlass/reduction/device/reduce_split_k.h @@ -56,6 +56,8 @@ public: using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef; using OutputTensorRef = typename ReductionKernel::OutputTensorRef; + using StrideIndex = typename ReductionKernel::StrideIndex; + /// Argument 
structure struct Arguments { diff --git a/include/cutlass/reduction/kernel/reduce_split_k.h b/include/cutlass/reduction/kernel/reduce_split_k.h index 870b94b8..9f189941 100644 --- a/include/cutlass/reduction/kernel/reduce_split_k.h +++ b/include/cutlass/reduction/kernel/reduce_split_k.h @@ -67,6 +67,7 @@ public: using WorkspaceTensorRef = TensorRef; using OutputTensorRef = TensorRef; + using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index; using FragmentWorkspace = AlignedArray; using FragmentAccumulator = Array; @@ -145,8 +146,8 @@ public: // Determine CTA position MatrixCoord thread_offset( - int(blockIdx.x) * Shape::kRow + threadIdx.y, - int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess + MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y), + MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess) ); // One guard conditional diff --git a/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h index 1dfe7e7e..8328ea40 100644 --- a/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h +++ b/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h @@ -68,7 +68,7 @@ struct TensorReductionAffineStridedParams { static int const kBatchSize = BatchSize; Coord extent; /// Extent of source tensor - FastDivmodU64 divmod[kRank - 2]; /// FastDivmod by each strided rank + FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K int64_t workspace_stride; /// stride (units of bytes) between workspace @@ -120,7 +120,7 @@ struct TensorReductionAffineStridedParams { reduction_identity(reduction_identity_) { // Initialize divisors for fast div-mod - for (int p = 1; p < kRank - 1; ++p) { + for (int p = 1; p < kRank; ++p) { divmod[p - 1] = 
FastDivmodU64(uint64_t(extent[p])); } @@ -204,7 +204,7 @@ private: uint64_t linear_idx) const { // Decompose into coordinate - coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank]); + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank - 1]); // Compute linear offset src_offset = 0; diff --git a/include/cutlass/tensor_coord.h b/include/cutlass/tensor_coord.h index 5c0c6031..5aff10f0 100644 --- a/include/cutlass/tensor_coord.h +++ b/include/cutlass/tensor_coord.h @@ -74,6 +74,11 @@ struct Tensor4DCoord : public Coord<4> { CUTLASS_HOST_DEVICE Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { } + /// Helper to construct from N, H, W, and C, which are LongIndex type + CUTLASS_HOST_DEVICE + Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c) + : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { } + /// Returns the batch of the coordinate CUTLASS_HOST_DEVICE Index const & n() const { return this->at(kN); } @@ -208,6 +213,11 @@ struct Tensor5DCoord : public Coord<5> { CUTLASS_HOST_DEVICE Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } + /// Helper to construct from N, D, H, W, and C, which are LongIndex type + CUTLASS_HOST_DEVICE + Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c) + : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { } + /// Returns the batch of the coordinate CUTLASS_HOST_DEVICE Index const & n() const { return this->at(kN); } diff --git a/include/cutlass/tensor_ref.h b/include/cutlass/tensor_ref.h index 2782b49f..c375233b 100644 --- a/include/cutlass/tensor_ref.h +++ b/include/cutlass/tensor_ref.h @@ -202,11 +202,17 @@ class TensorRef { // Methods // + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRef(): ptr_(nullptr) { + + } + /// Constructs a TensorRef with a pointer and layout object. 
CUTLASS_HOST_DEVICE TensorRef( - Element *ptr = nullptr, ///< pointer to start of tensor - Layout const &layout = Layout() ///< layout object containing stride and mapping function + Element *ptr, ///< pointer to start of tensor + Layout const &layout ///< layout object containing stride and mapping function ): ptr_(ptr), layout_(layout) { @@ -286,13 +292,13 @@ class TensorRef { /// Returns the layout object's stride in a given physical dimension CUTLASS_HOST_DEVICE - Index stride(int dim) const { + typename Layout::Stride::Index stride(int dim) const { return layout_.stride().at(dim); } /// Returns the layout object's stride in a given physical dimension CUTLASS_HOST_DEVICE - Index & stride(int dim) { + typename Layout::Stride::Index & stride(int dim) { return layout_.stride().at(dim); } diff --git a/include/cutlass/tensor_view.h b/include/cutlass/tensor_view.h index 333c559a..33fd9d42 100644 --- a/include/cutlass/tensor_view.h +++ b/include/cutlass/tensor_view.h @@ -117,9 +117,7 @@ class TensorView : public TensorRef { /// Constructs a TensorView object CUTLASS_HOST_DEVICE - TensorView(TensorCoord const &extent = TensorCoord()): extent_(extent) { - - } + TensorView() { } /// Constructs a TensorView object CUTLASS_HOST_DEVICE diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 67a7f1c7..f77b6079 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -98,7 +98,11 @@ struct alignas(4) tfloat32_t { CUTLASS_HOST_DEVICE explicit tfloat32_t(int x) { float flt = static_cast(x); + #if defined(__CUDA_ARCH__) storage = reinterpret_cast(flt); + #else + std::memcpy(&storage, &flt, sizeof(storage)); + #endif } /// Converts to float @@ -108,8 +112,14 @@ struct alignas(4) tfloat32_t { // Conversions to IEEE single-precision requires clearing dont-care bits // of the mantissa. 
unsigned bits = (storage & ~0x1fffu); - + + #if defined(__CUDA_ARCH__) return reinterpret_cast(bits); + #else + float flt; + std::memcpy(&flt, &bits, sizeof(flt)); + return flt; + #endif } /// Converts to float @@ -353,8 +363,8 @@ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { CUTLASS_HOST_DEVICE tfloat32_t operator-(tfloat32_t const& lhs) { - float x = -reinterpret_cast(lhs); - return reinterpret_cast(x); + float x = -static_cast(lhs); + return static_cast(x); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/transform/pitch_linear_thread_map.h b/include/cutlass/transform/pitch_linear_thread_map.h index 11285014..3e05605d 100644 --- a/include/cutlass/transform/pitch_linear_thread_map.h +++ b/include/cutlass/transform/pitch_linear_thread_map.h @@ -517,6 +517,9 @@ struct TransposePitchLinearThreadMap { layout::PitchLinearShape; + static_assert(Iterations::kContiguous == 1, + "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose"); + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) @@ -595,6 +598,9 @@ struct TransposePitchLinearThreadMapSimt { static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + static_assert(Iterations::kStrided == 1, + "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose"); + /// Shape of access by each thread using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; diff --git a/include/cutlass/transform/thread/transpose.h b/include/cutlass/transform/thread/transpose.h index 3ce1841a..8a93b8a2 100644 --- a/include/cutlass/transform/thread/transpose.h +++ b/include/cutlass/transform/thread/transpose.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -26,6 +26,7 @@ /*! \file \brief Basic copy routines for tensor views */ + #pragma once namespace cutlass { diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index a6bdca8f..8a058f60 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -47,6 +47,7 @@ #include "cutlass/predicate_vector.h" #include "cutlass/tensor_ref.h" #include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" //////////////////////////////////////////////////////////////////////////////// @@ -58,29 +59,15 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -/// PredicatedTileAccessIterator +/// PredicatedTileAccessIteratorPredicates /// -template -class PredicatedTileAccessIterator; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization of PredicatedTileAccessIterator for pitch-linear data. 
-/// -template -class PredicatedTileAccessIterator { +class PredicatedTileAccessIteratorPredicates { public: - static_assert( - AdvanceRank == 0 || AdvanceRank == 1, - "Specialization for pitch-linear iterator may along advance along the " - "contiguous(rank=0) or strided(rank=1) dimension."); - using Shape = Shape_; using Element = Element_; - using Layout = layout::PitchLinear; + using Layout = Layout_; static int const kAdvanceRank = AdvanceRank; using ThreadMap = ThreadMap_; using AccessType = AccessType_; @@ -88,16 +75,11 @@ class PredicatedTileAccessIterator; - using TensorView = TensorView; using TensorCoord = typename Layout::TensorCoord; - using Pointer = Element *; - using NonConstPointer = typename platform::remove_const::type *; - static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; - - static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), "Vectors implied by the thread map must be divisible by the access type."); static int const kPredicatesPerByte = 4; @@ -106,7 +88,7 @@ class PredicatedTileAccessIterator; - /// Parameters object is precomputed state and is host-constructible - class Params { - public: - friend PredicatedTileAccessIterator; - - private: - /// stride of pitch-linear layout (units of Element) - int stride_; - /// amount (in byte) to increment pointer to move to next access along - /// strided dimension - LongIndex inc_strided_; - /// amount (in byte) to increment pointer from last access to first access - /// of next tile - LongIndex inc_next_; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_; - - public: - - // Default ctor - CUTLASS_HOST_DEVICE - Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - 
Params(Layout const &layout) : stride_(layout.stride(0)) { - inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * - sizeof_bits::value / 8; - - if (kAdvanceRank) { - // advance along strided dimension - inc_advance_ = - Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; - } else { - // advance along contiguous dimension - inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; - } - - inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * - ThreadMap::Delta::kStrided * LongIndex(stride_) * - sizeof_bits::value / 8; - }; - }; - - private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char *; - - private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const ¶ms_; - - /// Internal pointer to first access of tile - BytePointer pointer_; - +// private: /// Guard predicates uint32_t predicates_[kPredicateWordCount]; @@ -189,9 +112,6 @@ class PredicatedTileAccessIterator( - const_cast(pointer))), - extent_(extent), - is_residue_tile_(true) { - + void set_predicates(int thread_id, TensorCoord const &threadblock_offset) { + TensorCoord residue_extent; if (kAdvanceRank) { - Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; if (!residue_size) { residue_size = Shape::kStrided; } @@ -292,7 +194,7 @@ class PredicatedTileAccessIterator::value * pointer_offset / 8; - } - - /// Advances an iterator along logical dimensions of matrix in units of whole tiles - CUTLASS_DEVICE - void add_tile_offset( - TensorCoord const &tile_offset) { - if (is_residue_tile_) { - - thread_offset_ += residue_offset_; - - Layout layout(params_.stride_); - add_pointer_offset(layout(residue_offset_)); - - compute_predicates_(extent_, true); - - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * 
LongIndex(tile_offset.strided() - 1); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } else { - if (kAdvanceRank) { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); - pointer_ += Shape::kContiguous * tile_offset.contiguous(); - } else { - pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); - pointer_ += Shape::kStrided * tile_offset.strided(); - } - } - is_residue_tile_ = false; - } - - /// Returns a pointer - CUTLASS_HOST_DEVICE - AccessType *get() const { - return reinterpret_cast( - pointer_ + - iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; - } - /// Increment and return an instance to self. CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator &operator++() { - - ++iteration_vector_; - if (iteration_vector_ < kAccessesPerVector) { - return *this; - } - - iteration_vector_ = 0; - ++iteration_contiguous_; - - if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { - return *this; - } - - // Enter here only if (iteration_contiguous_ == - // ThreadMap::Iteration::kContiguous) - iteration_contiguous_ = 0; - ++iteration_strided_; - - if (iteration_strided_ < ThreadMap::Iterations::kStrided) { - pointer_ += params_.inc_strided_; - return *this; - } - - // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) - // which means we enter the next tile. - iteration_strided_ = 0; - - // advance to next tile - pointer_ += params_.inc_next_; - - // now return to start tile - if the iterator is subsequently advanced, this - // subtraction as well as the subsequent integer addition are both elided by - // the compiler. - pointer_ -= params_.inc_advance_; + PredicatedTileAccessIteratorPredicates &operator++() { return *this; } - /// Increment and return an instance to self. 
- CUTLASS_HOST_DEVICE - PredicatedTileAccessIterator operator++(int) { - PredicatedTileAccessIterator self(*this); - operator++(); - return self; - } - /// Clears the predicate set efficiently CUTLASS_HOST_DEVICE void clear_mask() { @@ -492,9 +297,288 @@ class PredicatedTileAccessIterator +class PredicatedTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for pitch-linear data. +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout 
const &layout) : + Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()() + ) { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + is_residue_tile_(true) { + + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord 
extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + the_predicates.thread_offset_ += the_predicates.residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.residue_offset_)); + + the_predicates.compute_predicates_(the_predicates.extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { + the_predicates.clear_mask(); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } + }; //////////////////////////////////////////////////////////////////////////////// @@ -560,6 +644,11 @@ class PredicatedTileAccessIterator +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + 
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIterator; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. + LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; + } else { + // advance along 
contiguous dimension + inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + is_residue_tile_(true) { + + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters 
object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + the_predicates.thread_offset_ += the_predicates.residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.residue_offset_)); + + the_predicates.compute_predicates_(the_predicates.extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. 
+ /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { the_predicates.clear_mask(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = 
PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, 
thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { iterator_.clear_mask(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIterator< + 
layout::PitchLinearShape, Element, + layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} 
+ + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { iterator_.clear_mask(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Specialization of PredicatedTileAccessIterator for column-major interleaved data. /// It is mapped to the congruous layout. /// @@ -916,6 +1666,10 @@ class PredicatedTileAccessIterator; using TensorView = TensorView; @@ -108,51 +110,31 @@ class PredicatedTileAccessIterator2dThreadTile; - /// Parameters object is precomputed state and is host-constructible - class Params { + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + public: friend PredicatedTileAccessIterator2dThreadTile; - private: - /// stride of pitch-linear layout (units of Element) - int stride_; - /// amount (in byte) to increment pointer to move to next access along - /// strided dimension - int inc_strided_; - /// amount (in byte) to increment pointer from last access to first access - /// of next tile - int inc_next_; - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - int inc_advance_; - - public: + using Base = PredicatedTileAccessIteratorParams; // Default ctor CUTLASS_HOST_DEVICE - Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } 
+ Params() { } /// Construct the Params object given a pitch-linear tensor's layout CUTLASS_HOST_DEVICE - Params(Layout const &layout) : stride_(layout.stride(0)) { + Params(Layout const &layout) : + Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()() + ) { } - inc_strided_ = - (stride_ * ThreadMap::Delta::kStrided) * int(sizeof(Element)); - - if (kAdvanceRank) { - // advance along strided dimension - inc_advance_ = Shape::kStrided * stride_ * int(sizeof(Element)); - } else { - // advance along contiguous dimension - inc_advance_ = Shape::kContiguous * int(sizeof(Element)); - } - - inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) * - ThreadMap::Delta::kStrided * stride_ * - int(sizeof(Element)); - }; + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } }; + private: /// Internal pointer type permits fast address arithmetic using BytePointer = char *; @@ -537,7 +519,12 @@ class PredicatedTileAccessIterator2dThreadTile + struct MakePredicatedTileAccessIteratorDesc; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for pitch-linear data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> { + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return PredicatedTileAccessIteratorDesc( + sizeof_bits::value, + AdvanceRank, + {Shape::kContiguous, Shape::kStrided}, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ); +} + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major data. 
+template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); +} + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); +} + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap, int InterleavedK> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::ColumnMajorInterleaved, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + static int const kInterleavedK = InterleavedK; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); +} + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major interleaved data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap, int InterleavedK> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::RowMajorInterleaved, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + static int const kInterleavedK = InterleavedK; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); +} + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Parameters struct +// + +struct PredicatedTileAccessIteratorParams { + + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + /// stride of pitch-linear layout (units of Element) + LongIndex stride_; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + + stride_ = stride; + + inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) * + desc.element_size_bits / 8; + + if (desc.advance_rank) { + // advance along strided dimension + 
inc_advance_ = + desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8; + } else { + // advance along contiguous dimension + inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8; + } + + inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) * + desc.threadmap_delta.strided() * LongIndex(stride_) * + desc.element_size_bits / 8; + + return Status::kSuccess; + + } + + CUTLASS_HOST_DEVICE + Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) { + return initialize(LongIndex(stride), desc); + } + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams() { + initialize(LongIndex(0), PredicatedTileAccessIteratorDesc()); + } + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) { + initialize(stride, desc); + } + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + initialize(stride, desc); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_iterator.h index adc4cb15..dc562dc7 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_iterator.h @@ -189,6 +189,8 @@ class PredicatedTileIterator +class PredicatedTileIterator, AdvanceRank, + ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = 
layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIterator; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + + friend PredicatedTileIterator; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + CUTLASS_HOST_DEVICE + Params() { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : 
address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { address_iterator_.clear_mask(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + 
CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for affine rank 2 column-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class PredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) { + + } + }; + +private: + + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + 
CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { + iterator_.clear_mask(); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + 
CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for affine rank 2 row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class PredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + +private: + + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + 
CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask() { + iterator_.clear_mask(); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + 
CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + /// Specialization of PredicatedTileIterator for interleaved data. It is mapped /// to the congruous layout. /// @@ -854,6 +1518,11 @@ class PredicatedTileIterator; using TensorCoord = typename Layout::TensorCoord; @@ -89,7 +90,7 @@ class RegularTileAccessIterator< // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType *pointer_; diff --git a/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h index e0c44b1c..76d83b2c 100644 --- a/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h +++ b/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h @@ -76,6 +76,7 @@ class RegularTileAccessIterator< using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -107,7 +108,7 @@ class RegularTileAccessIterator< // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType *pointer_[Detail::kPointerCount]; @@ -445,6 +446,7 @@ class RegularTileAccessIterator; using TensorCoord = typename Layout::TensorCoord; @@ -492,7 +494,7 @@ class RegularTileAccessIterator; using TensorCoord = typename 
Layout::TensorCoord; @@ -107,7 +108,7 @@ class RegularTileAccessIterator< // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType *pointer_; @@ -437,6 +438,7 @@ class RegularTileAccessIterator< using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -471,7 +473,7 @@ class RegularTileAccessIterator< // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType *pointer_; @@ -811,6 +813,7 @@ class RegularTileAccessIterator< using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -844,7 +847,7 @@ class RegularTileAccessIterator< // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType *pointer_; @@ -1175,6 +1178,7 @@ class RegularTileAccessIterator< using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -1211,7 +1215,7 @@ class RegularTileAccessIterator< // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType *pointer_; diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h index 831131f0..e9d09351 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -70,11 +70,14 @@ public: using Index = typename Layout::Index; using LongIndex = 
typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; using Fragment = Array; + + using AccessType = AlignedArray; static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, "Advance rank may only be along the contiguous or strided dimensions."); @@ -84,8 +87,6 @@ private: // // Types // - - using AccessType = AlignedArray; // // Data members @@ -95,7 +96,7 @@ private: uint8_t *pointer_; /// Stride quantity - Index stride_; + StrideIndex stride_; /// Amount to increment pointer along strided dimension Index increment_strided_; @@ -242,6 +243,30 @@ public: (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; add_pointer_offset(offset); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { +#if 0 + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); +#endif + return reinterpret_cast(pointer_); + } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -281,6 +306,8 @@ public: kAlignment >; + using AccessType = typename Underlying::AccessType; + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, "Advance rank may only be along the row or column dimensions."); @@ -364,6 +391,17 @@ public: iterator_.add_tile_offset({coord.column(), coord.row()}); } + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + 
CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -402,6 +440,8 @@ public: ThreadMap >; + using AccessType = typename Underlying::AccessType; + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, "Advance rank may only be along the row or column dimensions."); @@ -485,6 +525,17 @@ public: iterator_.add_tile_offset({coord.row(), coord.column()}); } + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h index abfba6b8..2469e8d2 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h @@ -79,6 +79,7 @@ public: using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -104,13 +105,13 @@ private: uint8_t *pointer_; /// Stride quantity - Index stride_; + StrideIndex stride_; /// Amount to increment pointer along strided dimension - Index increment_strided_; + LongIndex increment_strided_; /// Amount to advance pointer between tiles - Index increment_advance_; + LongIndex increment_advance_; public: diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h index 0d2bbeea..03302a62 100644 
--- a/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h @@ -85,6 +85,7 @@ public: using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -123,7 +124,7 @@ private: // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType * pointer_[Detail::kPointerCount]; @@ -557,6 +558,7 @@ public: using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; using TensorRef = TensorRef; using TensorCoord = typename Layout::TensorCoord; @@ -595,7 +597,7 @@ private: // /// Stride value - Index stride_; + StrideIndex stride_; /// Internal pointer to first access of tile AccessType * pointer_[Detail::kPointerCount]; diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index cfcb696e..c70e93fb 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -33,13 +33,14 @@ #include #else #include +#include #include #include #include #endif #include "cutlass/cutlass.h" - +#include "cutlass/numeric_types.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -62,15 +63,18 @@ namespace cutlass { struct uint128_t { /// Size of one part of the uint's storage in bits - int const kPartSize = sizeof(uint64_t) * 8; + int const kPartSize = sizeof_bits::value; + + struct hilo { + uint64_t lo; + uint64_t hi; + + CUTLASS_HOST_DEVICE hilo(uint64_t lo_, uint64_t hi_):lo(lo_), hi(hi_) {} + }; // Use a union to store either low and high parts or, if present, a built-in 128b integer type. 
union { - - struct { - uint64_t lo; - uint64_t hi; - }; + struct hilo hilo_; #if defined(CUTLASS_UINT128_NATIVE) unsigned __int128 native; @@ -83,15 +87,15 @@ struct uint128_t { /// Default ctor CUTLASS_HOST_DEVICE - uint128_t(): lo(0), hi(0) { } + uint128_t(): hilo_(0, 0) { } /// Constructor from uint64 CUTLASS_HOST_DEVICE - uint128_t(uint64_t lo_): lo(lo_), hi(0) { } + uint128_t(uint64_t lo_): hilo_(lo_, 0) { } /// Constructor from two 64b unsigned integers CUTLASS_HOST_DEVICE - uint128_t(uint64_t lo_, uint64_t hi_): lo(lo_), hi(hi_) { + uint128_t(uint64_t lo_, uint64_t hi_): hilo_(lo_, hi_) { } @@ -103,7 +107,7 @@ struct uint128_t { /// Lossily cast to uint64 CUTLASS_HOST_DEVICE explicit operator uint64_t() const { - return lo; + return hilo_.lo; } CUTLASS_HOST_DEVICE @@ -111,7 +115,8 @@ struct uint128_t { #if defined(__CUDA_ARCH__) asm volatile (" brkpt;\n"); #else - throw std::runtime_error("Not yet implemented."); + // throw std::runtime_error("Not yet implemented."); + abort(); #endif } @@ -122,8 +127,8 @@ struct uint128_t { #if defined(CUTLASS_UINT128_NATIVE) y.native = native + rhs.native; #else - y.lo = lo + rhs.lo; - y.hi = hi + rhs.hi + (!y.lo && (rhs.lo)); + y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo)); #endif return y; } @@ -135,8 +140,8 @@ struct uint128_t { #if defined(CUTLASS_UINT128_NATIVE) y.native = native - rhs.native; #else - y.lo = lo - rhs.lo; - y.hi = hi - rhs.hi - (rhs.lo && y.lo > lo); + y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); #endif return y; } @@ -149,11 +154,11 @@ struct uint128_t { y.native = native * rhs; #elif defined(CUTLASS_INT128_ARITHMETIC) // Multiply by the low part - y.lo = _umul128(lo, rhs, &y.hi); + y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); // Add the high part and ignore the overflow uint64_t overflow; - y.hi += _umul128(hi, rhs, &overflow); + y.hilo_.hi += 
_umul128(hilo_.hi, rhs, &overflow); #else // TODO - not implemented exception(); @@ -170,7 +175,7 @@ struct uint128_t { #elif defined(CUTLASS_INT128_ARITHMETIC) // implemented using MSVC's arithmetic intrinsics uint64_t remainder = 0; - quotient = _udiv128(hi, lo, divisor, &remainder); + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else // TODO - not implemented exception(); @@ -186,7 +191,7 @@ struct uint128_t { remainder = uint64_t(native % divisor); #elif defined(CUTLASS_INT128_ARITHMETIC) // implemented using MSVC's arithmetic intrinsics - (void)_udiv128(hi, lo, divisor, &remainder); + (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else // TODO - not implemented exception(); @@ -203,7 +208,7 @@ struct uint128_t { remainder = uint64_t(native % divisor); #elif defined(CUTLASS_INT128_ARITHMETIC) // implemented using MSVC's arithmetic intrinsics - quotient = _udiv128(hi, lo, divisor, &remainder); + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else // TODO - not implemented exception(); @@ -218,12 +223,12 @@ struct uint128_t { return *this; } else if (sh >= kPartSize) { - return uint128_t(0, lo << (sh - kPartSize)); + return uint128_t(0, hilo_.lo << (sh - kPartSize)); } else { return uint128_t( - (lo << sh), - (hi << sh) | uint64_t(lo >> (kPartSize - sh)) + (hilo_.lo << sh), + (hilo_.hi << sh) | uint64_t(hilo_.lo >> (kPartSize - sh)) ); } } @@ -235,12 +240,12 @@ struct uint128_t { return *this; } else if (sh >= kPartSize) { - return uint128_t((hi >> (sh - kPartSize)), 0); + return uint128_t((hilo_.hi >> (sh - kPartSize)), 0); } else { return uint128_t( - (lo >> sh) | (hi << (kPartSize - sh)), - (hi >> sh) + (hilo_.lo >> sh) | (hilo_.hi << (kPartSize - sh)), + (hilo_.hi >> sh) ); } } diff --git a/media/docs/efficient_gemm.md b/media/docs/efficient_gemm.md index a8374fd8..afdbe54b 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/efficient_gemm.md @@ -24,7 +24,7 @@ for (int cta_n = 0; cta_n < GemmN; cta_n += 
CtaTileN) { // f for (int warp_n = 0; warp_n < CtaTileN; warp_n += WarpTileN) { // for each warp_y } warp-level parallelism for (int warp_m = 0; warp_m < CtaTileM; warp_m += WarpTileM) { // for each warp_x } // - for (int warp_k = 0; warp_k < CtaTileK; warp_k += MmaK) { // fully unroll across CtaTileK + for (int warp_k = 0; warp_k < CtaTileK; warp_k += WarpTileK) { // fully unroll across CtaTileK // - one iteration of this loop is one "k Group" // for (int mma_k = 0; mma_k < WarpTileK; mma_k += MmaK) { // for each mma instruction } instruction-level parallelism diff --git a/media/docs/utilities.md b/media/docs/utilities.md index 78285c84..7df5e207 100644 --- a/media/docs/utilities.md +++ b/media/docs/utilities.md @@ -211,7 +211,7 @@ int main() { ``` -`TensorFillRandomGaussian()` for initializing elements to a random Gaussian distribution. +`TensorFillRandomGaussian()` for initializing elements to a random gaussian distribution. The device-side implementation uses CURAND to generate random numbers. 
```c++ #include diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index e60c232a..bfc8db05 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -141,6 +141,9 @@ cutlass_test_unit_add_executable( conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu + + conv2d_fprop_with_broadcast_sm75.cu + conv2d_fprop_with_reduction_sm75.cu ) if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) @@ -158,15 +161,18 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) cutlass_test_unit_add_executable( cutlass_test_unit_conv_device_tensorop_f32_sm80 - + conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu + + # Strided Dgrad + conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu ) - + # Conv2d - TF32 input, F32 output, F32 accumulation cutlass_test_unit_add_executable( diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu index ba53d6f7..481ff521 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu @@ -71,7 +71,8 @@ TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_s cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, cutlass::arch::OpMultiplyAddComplex, - 
cutlass::conv::IteratorAlgorithm::kAnalytic + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity >::Kernel; using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu index dc3f9d50..dcfd3a57 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu @@ -38,140 +38,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with 
device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 64x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 128x128_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using 
Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, @@ -208,52 +74,7 @@ TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_s cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d 
instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::IteratorAlgorithm::kAnalytic, cutlass::conv::StrideSupport::kUnity >::Kernel; diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu index e3eb0736..320e6789 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu @@ -69,7 +69,8 @@ TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity >::Kernel; using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu index ff512c02..364f3b98 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu +++ 
b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu @@ -66,7 +66,9 @@ TEST(SM70_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, - cutlass::arch::OpMultiplyAdd + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity >::Kernel; using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu index 212290cb..a4ed04f8 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu @@ -67,7 +67,9 @@ TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, - cutlass::arch::OpMultiplyAdd + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity >::Kernel; using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index b1fc52f4..069b21f6 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -36,88 +36,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// 
-TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, - 128x128_32x3_64x64x32) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, cutlass::layout::TensorNHWC, - ElementB, cutlass::layout::TensorNHWC, - ElementC, cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic, - cutlass::conv::StrideSupport::kStrided - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, - 128x128_32x3_64x64x32) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, cutlass::layout::TensorNHWC, - ElementB, cutlass::layout::TensorNHWC, - 
ElementC, cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic, - cutlass::conv::StrideSupport::kUnity - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, 128x128_32x3_64x64x32) { @@ -281,6 +199,5 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ten /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d()); } - //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu index 542e1e6b..758fe12c 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu @@ -37,95 +37,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 32x64_8x4_32x64x8) 
{ - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 64x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - 
cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, @@ -162,107 +73,7 @@ TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - test::conv::device::Conv2dProblemVector user_size; - - user_size.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, 4}, // input size (NHWC) - {8, 1, 1, 4}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d(user_size)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 128x128_8x4_64x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - 
ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<64, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::IteratorAlgorithm::kAnalytic, cutlass::conv::StrideSupport::kUnity >::Kernel; @@ -273,6 +84,7 @@ 
TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_sim } + //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_64x32x8) { diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu index 9ef2c7f6..13c283b3 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu @@ -130,93 +130,3 @@ TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_s } //////////////////////////////////////////////////////////////////////////////// -TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 64x64_8x2_32x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, - 2, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - 
- using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM50_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x2_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, - 2, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - - diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu index baece322..9e19e921 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu +++ 
b/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu @@ -37,184 +37,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 64x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = 
cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 128x128_8x4_64x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<64, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 128x128_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, 128x128_8x4_64x32x8) { @@ -260,50 +82,6 @@ 
TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_s } -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, @@ -348,50 +126,6 @@ TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_ /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d()); } -//////////////////////////////////////////////////////////////////////////////// - 
-TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 64x64_8x3_64x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<64, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); -} //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu index 3637fe8c..cebeef52 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu @@ -37,96 +37,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// 
-TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 64x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - 
cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_32x64x8) { @@ -167,106 +77,6 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - test::conv::device::Conv2dProblemVector user_size; - - user_size.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, 4}, // input size (NHWC) - {8, 1, 1, 4}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d(user_size)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 128x128_8x4_64x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - 
cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<64, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; - /// Run all unit test sizes with device-level Conv2d instance 
EXPECT_TRUE(test::conv::device::TestAllConv2d()); diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu new file mode 100755 index 00000000..76a8e017 --- /dev/null +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, + 16x32_8x2_16x16x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::Quaternion; + using ElementB = cutlass::Quaternion; + using ElementC = cutlass::Quaternion; + using ElementAccumulator = cutlass::Quaternion; + using ElementCompute = cutlass::Quaternion; + + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + cutlass::gemm::GemmShape<16, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, + 16x64_8x2_8x32x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::Quaternion; + using ElementB = cutlass::Quaternion; + using ElementC = cutlass::Quaternion; + using ElementAccumulator = cutlass::Quaternion; + using ElementCompute = cutlass::Quaternion; + + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + cutlass::gemm::GemmShape<16, 64, 8>, + cutlass::gemm::GemmShape<8, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, + 32x32_8x2_16x16x8) { + + /// Conv operation element types for the Gemm 
equivalent (ImplicitGemm) + using ElementA = cutlass::Quaternion; + using ElementB = cutlass::Quaternion; + using ElementC = cutlass::Quaternion; + using ElementAccumulator = cutlass::Quaternion; + using ElementCompute = cutlass::Quaternion; + + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM50_Device_Conv2d_Fprop_Optimized_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, + 16x32_8x2_16x16x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::Quaternion; + using ElementB = cutlass::Quaternion; + using ElementC = cutlass::Quaternion; + using ElementAccumulator = cutlass::Quaternion; + using ElementCompute = cutlass::Quaternion; + + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + cutlass::gemm::GemmShape<16, 32, 8>, + 
cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu new file mode 100644 index 00000000..e8766730 --- /dev/null +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_with_broadcast_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + cutlass::half_t, + float, + float, + cutlass::half_t, + cutlass::half_t, + 8, + cutlass::epilogue::thread::GELU_taylor + >; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + 
ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu b/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu new file mode 100644 index 00000000..109980bc --- /dev/null +++ b/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu @@ -0,0 +1,88 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/thread/linear_combination_with_elementwise.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_with_reduction_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +TEST(SM75_Device_Conv2d_Fprop_With_Reduction_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationWithElementwise< + float, + float, + cutlass::half_t, + cutlass::half_t, + 8 + >; + + /// Device-level Conv2d instance + using Conv2dFpropKernel = typename 
cutlass::conv::kernel::DefaultConv2dFpropWithReduction< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2dWithReduction()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_problems.h b/test/unit/conv/device/conv2d_problems.h index c532894e..095fed79 100644 --- a/test/unit/conv/device/conv2d_problems.h +++ b/test/unit/conv/device/conv2d_problems.h @@ -161,7 +161,7 @@ struct TestbedConv2dProblemSizes { void initialize_conv2d_default_sizes() { //////////////////////////////////////////////////////////////////////////////////////////// - // Very Small input size (1x8x8xminimum_channel_size), filter size (3x3 - 7x7), stride (1,1) + // Small input size x stride (1,1) // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} //////////////////////////////////////////////////////////////////////////////////////////// @@ -229,6 +229,58 @@ struct TestbedConv2dProblemSizes { {1, 1} // dilation (dilation_h, dilation_w) )); + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (2,2) + // C < CTA::K and non-multiples of CTA::K. 
Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 11, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 11, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 13, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, minimum_channel_size}, // input size (NHWC) + {8, 2, 2, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 5, 5, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, 8}, // input size (NHWC) + {8, 3, 3, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation 
(dilation_h, dilation_w) + )); + //////////////////////////////////////////////////////////////////////////////////// // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) //////////////////////////////////////////////////////////////////////////////////// @@ -239,7 +291,15 @@ struct TestbedConv2dProblemSizes { {1, 1}, // stride (stride_h, stride_w) {1, 1} // dilation (dilation_h, dilation_w) )); - + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 19, 37, 160}, // input size (NHWC) + {224, 3, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 160}, // input size (NHWC) + {224, 2, 3, 160}, // filter size (KRSC) @@ -284,16 +344,8 @@ struct TestbedConv2dProblemSizes { )); //////////////////////////////////////////////////////////////////////////////////// - // Medium input size (1x16x16x128), filter size (1x1, 3,x3, 5x5), stride (2, 2) - //////////////////////////////////////////////////////////////////////////////////// - conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 19, 37, 160}, // input size (NHWC) - {224, 3, 3, 160}, // filter size (KRSC) - {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) - {2, 2}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - + // Medium input size, filter size (1x1, 3x3, 5x5, 7x7), stride (2, 2) + //////////////////////////////////////////////////////////////////////////////////// conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 288}, // input size (NHWC) + {160, 5, 5, 288}, // filter size (KRSC) @@ -302,6 +354,61 @@ struct TestbedConv2dProblemSizes { {1, 1} // dilation (dilation_h, dilation_w) )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 55, 55, 256}, // input size (NHWC) + 
{512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 80, 80, 32}, // input size (NHWC) + {64, 5, 5, 32}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 224, 224, 8}, // input size (NHWC) + {64, 7, 7, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size stride (3, 3), filter (3, 3), non-default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size *mixed* stride (1, 2) and (2, 1), + // filter (3, 3), default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // 
filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + ///////////////////////////////////////////////////////////////////////////// // Additional input size ///////////////////////////////////////////////////////////////////////////// @@ -347,15 +454,15 @@ struct TestbedConv2dProblemSizes { #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 124, 224, 96}, // input size (NHWC) - {24, 7, 7, 96}, // filter size (KRSC) - {1, 229, 129, 32} // output size (NPQK) + {1, 124, 224, 96}, // input size (NHWC) + {24, 7, 7, 96}, // filter size (KRSC) + {1, 229, 129, 32} // output size (NPQK) )); conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( - {1, 233, 35, 48}, // input size (NHWC) - {24, 7, 5, 48}, // filter size (KRSC) - {1, 233, 35, 24} // output size (NPQK) + {1, 233, 35, 48}, // input size (NHWC) + {24, 7, 5, 48}, // filter size (KRSC) + {1, 233, 35, 24} // output size (NPQK) )); #endif diff --git a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu new file mode 100644 index 00000000..53eca421 --- /dev/null +++ b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad (Analytic) +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, + 128x128_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + + test::conv::device::Conv2dProblemVector problem_size_list; + +#if 0 // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 
1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 55, 55, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + +#endif + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, + 128x256_32x3_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + 
EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, + 128x256_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 9b94a4db..f4502cf9 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -81,7 +81,7 @@ public: >; using ReductionDevice = cutlass::reduction::device::ReduceSplitK; - + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; 
public: @@ -161,7 +161,7 @@ public: initialize_tensor(tensor_A.host_view(), init_A, seed); initialize_tensor(tensor_B.host_view(), init_B, seed * 17); initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - + tensor_A.sync_device(); tensor_B.sync_device(); tensor_C.sync_device(); @@ -214,7 +214,7 @@ public: #if 0 //display conv2d problem size for debugging std::cout << problem_size << std::endl - << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl << std::endl; #endif @@ -262,7 +262,7 @@ public: if (status != cutlass::Status::kSuccess) { return false; } - + // run conv2d operator status = conv2d_op(); @@ -271,6 +271,7 @@ public: return false; } + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { // configure parallel reduction operator @@ -280,10 +281,20 @@ public: cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), problem_size.split_k_slices, cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), - {reinterpret_cast (workspace.get()), tensor_C.stride(Conv2d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {tensor_D_computed.device_data(), tensor_C.stride(Conv2d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {tensor_C.device_data(), tensor_C.stride(Conv2d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {alpha, beta} // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + 
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} ); status = reduction_op.initialize(reduction_args, nullptr); @@ -302,7 +313,11 @@ public: } } bool passed = false; - + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + tensor_D_computed.sync_host(); #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED @@ -326,10 +341,6 @@ public: alpha, beta); - cudaError_t result = cudaDeviceSynchronize(); - EXPECT_EQ(result, cudaSuccess) << " device reference error: " - << cudaGetErrorString(result); - // sync host (copy device data to host) for dumping error output in case of mismatches tensor_D_reference.sync_host(); @@ -445,7 +456,7 @@ bool TestAllConv2d( Conv2dProblemVector const *problem_vectors[] = { &conv_test_sizes, // run user specified sizes &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes + //&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled #endif @@ -467,7 +478,7 @@ bool TestAllConv2d( // Procedurally disable certain cases // - // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == @@ -477,6 +488,18 @@ bool TestAllConv2d( } } + // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + // Although strided dgrad works for all stride combinations, we are only going + // to run strided dgrad for non-unity strides + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + 
(ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + // // Test // @@ -491,7 +514,7 @@ bool TestAllConv2d( if (!passed) { return false; } - + // test mode = convolution passed = testbed.run( conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), @@ -503,6 +526,30 @@ bool TestAllConv2d( } } + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters // which are abolutely neccessary to catch functional bugs. 
The below code does provide option to sweep diff --git a/test/unit/conv/device/conv2d_testbed_interleaved.h b/test/unit/conv/device/conv2d_testbed_interleaved.h index 06ab207d..366d0682 100644 --- a/test/unit/conv/device/conv2d_testbed_interleaved.h +++ b/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -82,7 +82,7 @@ public: >; using ReductionDevice = cutlass::reduction::device::ReduceSplitK; - + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; public: @@ -245,10 +245,20 @@ public: cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), problem_size.split_k_slices, cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), - {reinterpret_cast (workspace.get()), tensor_C.stride(Conv2d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {tensor_D_computed.device_data(), tensor_C.stride(Conv2d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {tensor_C.device_data(), tensor_C.stride(Conv2d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {alpha, beta} // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} ); status = reduction_op.initialize(reduction_args, nullptr); diff --git a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu index dbc55332..69497fb8 100644 --- 
a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu +++ b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu @@ -36,51 +36,6 @@ #include "conv2d_testbed.h" -//////////////////////////////////////////////////////////////////////////////// -TEST(SM50_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x2_32x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM50_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, 64x64_8x2_32x32x8) { diff --git a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu 
b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu index 6cf9b15f..92e1e4a9 100644 --- a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu @@ -37,95 +37,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 64x64_8x4_32x64x8) { - - /// Conv operation element 
types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, @@ -172,96 +83,6 @@ TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_s } -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 128x128_8x4_64x32x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename 
cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<64, 32, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = cutlass::complex; - using ElementB = cutlass::complex; - using ElementC = cutlass::complex; - using ElementAccumulator = cutlass::complex; - using ElementCompute = cutlass::complex; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAddComplex, - 
cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, 128x128_8x4_64x32x8) { diff --git a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu index 5645c90d..fea0519b 100644 --- a/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu @@ -37,151 +37,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - 
cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 64x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 128x128_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// 
Device-level Conv2d instance - using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - test::conv::device::Conv2dProblemVector user_size; - - user_size.push_back(cutlass::conv::Conv2dProblemSize( - {1, 8, 8, 4}, // input size (NHWC) - {8, 1, 1, 4}, // filter size (KRSC) - {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) - {1, 1}, // stride (stride_h, stride_w) - {1, 1} // dilation (dilation_h, dilation_w) - )); - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d(user_size)); - -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_64x32x8) { @@ -227,51 +82,6 @@ TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt } -//////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, - 32x64_8x4_32x64x8) { - - /// Conv operation element types for the Gemm equivalent (ImplicitGemm) - using ElementA = float; - using ElementB = float; - using ElementC = float; - using ElementAccumulator = float; - using ElementCompute = float; - - - /// Device-level Conv2d instance - using 
Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< - ElementA, - cutlass::layout::TensorNHWC, - ElementB, - cutlass::layout::TensorNHWC, - ElementC, - cutlass::layout::TensorNHWC, - ElementAccumulator, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<32, 64, 8>, - cutlass::gemm::GemmShape<1, 1, 1>, - cutlass::epilogue::thread::LinearCombination< - ElementC, - 1, - ElementAccumulator, - ElementCompute - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized - >::Kernel; - - using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; - - /// Run all unit test sizes with device-level Conv2d instance - EXPECT_TRUE(test::conv::device::TestAllConv2d()); - -} - //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_64x32x8) { diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h new file mode 100644 index 00000000..f85134a4 --- /dev/null +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -0,0 +1,551 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2dWithBroadcast { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv2dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = 
cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 3; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements 
and waive if not satisfied + // + + int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + 
kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" + << "\nD computed:\n" << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2dWithBroadcast( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + 
&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + 
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. 
+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/test/unit/conv/device/conv2d_with_reduction_testbed.h b/test/unit/conv/device/conv2d_with_reduction_testbed.h new file mode 100644 index 00000000..97526b70 --- /dev/null +++ b/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -0,0 +1,568 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2dWithReduction { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementT = typename EpilogueOutputOp::ElementTensor; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv2dWithReduction( + cutlass::Distribution::Kind init_A_ = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 3; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + tensor_Reduction.resize({ + (problem_size.N * problem_size.P * problem_size.Q), + (problem_size.K - 1 + Conv2d::ThreadblockShape::kN) / Conv2d::ThreadblockShape::kN + }); + + tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); + + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, 
problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + static_cast(tensor_Reduction.stride()[0]), + static_cast(tensor_Tensor.stride()[0]) + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + 
cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" + << "\nD computed:\n" << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2dWithReduction( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithReduction testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + 
&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + 
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. 
+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index 87ac39ab..4f834f13 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ -81,7 +81,8 @@ public: >; using ReductionDevice = cutlass::reduction::device::ReduceSplitK; - + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + public: /// Initialization @@ -281,10 +282,20 @@ public: cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), problem_size.split_k_slices, cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), - {reinterpret_cast (workspace.get()), tensor_C.stride(Conv3d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {tensor_D_computed.device_data(), tensor_C.stride(Conv3d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {tensor_C.device_data(), 
tensor_C.stride(Conv3d::ImplicitGemmKernel::kTensorCStrideIdx)}, - {alpha, beta} // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} ); status = reduction_op.initialize(reduction_args, nullptr); @@ -304,6 +315,38 @@ public: } bool passed = false; + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta + ); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else cutlass::reference::host::Conv3d< ElementA, LayoutA, @@ -323,8 +366,7 @@ public: alpha, beta ); - - tensor_D_computed.sync_host(); +#endif passed = cutlass::reference::host::TensorEquals( tensor_D_computed.host_view(), diff --git a/test/unit/core/complex.cu b/test/unit/core/complex.cu index 59812d6e..04a6798b 100644 --- a/test/unit/core/complex.cu +++ b/test/unit/core/complex.cu @@ -32,6 +32,7 @@ #include "../common/cutlass_unit_test.h" #include "cutlass/complex.h" +#include "cutlass/constants.h" #include 
"cutlass/numeric_conversion.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -85,6 +86,42 @@ TEST(complex, f16_to_f32_conversion) { //////////////////////////////////////////////////////////////////////////////////////////////////// +TEST(complex, exp_f32) { + + cutlass::complex Z[] = { + {1, 1}, + {2 , cutlass::constants::pi()/2.0f }, + {0.5f, cutlass::constants::pi() }, + {0.25f, cutlass::constants::pi()*3/4.0f }, + {0, 0}, + }; + + cutlass::complex Expected[] = { + {1.4686939399158851, 2.2873552871788423}, + {4.524491950137825e-16, 7.38905609893065}, + {-1.6487212707001282, 2.019101226849069e-16}, + {-0.9079430793557842, 0.9079430793557843}, + {1, 0} + }; + + double tolerance = 0.00001; + + for (int i = 0; cutlass::real(Z[i]); ++i) { + double e_r = cutlass::real(Expected[i]); + double e_i = cutlass::real(Expected[i]); + + cutlass::complex got = cutlass::exp(Z[i]); + float g_r = cutlass::real(got); + float g_i = cutlass::real(got); + + EXPECT_TRUE( + std::abs(g_r - e_r) < tolerance && std::abs(g_i - e_i) < tolerance + ) << "Expected(" << Expected[i] << "), Got(" << got << ")"; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + namespace test { /// Thorough testing for basic complex math operators. Uses std::complex as a reference. 
diff --git a/test/unit/core/functional.cu b/test/unit/core/functional.cu index a3b98f70..038576e4 100644 --- a/test/unit/core/functional.cu +++ b/test/unit/core/functional.cu @@ -29,6 +29,7 @@ #include "../common/cutlass_unit_test.h" #include "cutlass/functional.h" +#include "cutlass/core_io.h" #include "cutlass/layout/matrix.h" #include "cutlass/util/host_tensor.h" @@ -78,16 +79,16 @@ __global__ void trinary_operator( Operator op; - Element a_x = *a; - Element b_x = *b; - Element c_x = *c; + Element a_x = a[blockIdx.x]; + Element b_x = b[blockIdx.x]; + Element c_x = c[blockIdx.x]; CUTLASS_PRAGMA_NO_UNROLL for (int i = 0; i < Iterations; ++i) { c_x = op(a_x, b_x, c_x); } - *d = c_x; + d[blockIdx.x] = c_x; } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -421,3 +422,67 @@ TEST(Functional, multiply_add_bf16x17) { ///////////////////////////////////////////////////////////////////////////////////////////////// +template +cutlass::Quaternion random_quaternion(int range) { + return cutlass::Quaternion{ + T((rand() % range * 2) - range), + T((rand() % range * 2) - range), + T((rand() % range * 2) - range), + T((rand() % range * 2) - range) + }; +} + +template +void Functional_multiply_add_QuaternionT() { + + using Element = cutlass::Quaternion; + using Operator = cutlass::multiply_add; + using HostTensor = cutlass::HostTensor; + + int const kM = 128; + int const kRange = 8; + + HostTensor A({kM, 1}); + HostTensor B({kM, 1}); + HostTensor C({kM, 1}); + HostTensor D({kM, 1}); + + srand(2021); + + for (int m = 0; m < kM; ++m) { + A.at({m, 0}) = random_quaternion(kRange); + B.at({m, 0}) = random_quaternion(kRange); + C.at({m, 0}) = random_quaternion(kRange); + } + + A.sync_device(); + B.sync_device(); + C.sync_device(); + D.sync_device(); + + test::core::kernel::trinary_operator<<< dim3(kM,1), dim3(1,1) >>>( + D.device_data(), + A.device_data(), + B.device_data(), + C.device_data() + ); + + D.sync_host(); + + for (int m = 
0; m < kM; ++m) { + + Element a = A.at({m, 0}); + Element b = B.at({m, 0}); + Element c = C.at({m, 0}); + Element got = D.at({m, 0}); + Element expected = a * b + c; + + EXPECT_TRUE(got == expected); + } +} + +TEST(Functional, multiply_add_quaternion_f32) { + Functional_multiply_add_QuaternionT(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/matrix.cu b/test/unit/core/matrix.cu index f94605d7..b4be22c4 100644 --- a/test/unit/core/matrix.cu +++ b/test/unit/core/matrix.cu @@ -32,6 +32,7 @@ #include "../common/cutlass_unit_test.h" #include "cutlass/matrix.h" +#include "cutlass/core_io.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index 48275ea2..262b10e7 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -122,13 +122,14 @@ TEST(Epilogue_thread_linear_combination, device_side_f16_f32_ptr) { ///////////////////////////////////////////////////////////////////////////////////////////////// + TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { using Element = cutlass::half_t; using ElementOutput = cutlass::half_t; int const kCount = 8; - using LinearCombination = cutlass::epilogue::thread::LinearCombinationGELU< + using LinearCombinationGELU = cutlass::epilogue::thread::LinearCombinationGELU< ElementOutput, kCount, Element, @@ -137,9 +138,9 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { Element alpha = Element(1); Element beta = Element(0); - typename LinearCombination::Params params(&alpha, &beta); + typename LinearCombinationGELU::Params params(&alpha, &beta); - LinearCombination linear_combination_op(params); + LinearCombinationGELU linear_combination_op(params); cutlass::Array accum; @@ -157,4 +158,4 @@ 
TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { } } -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/CMakeLists.txt b/test/unit/epilogue/threadblock/CMakeLists.txt index b987a05c..04475b33 100755 --- a/test/unit/epilogue/threadblock/CMakeLists.txt +++ b/test/unit/epilogue/threadblock/CMakeLists.txt @@ -32,4 +32,5 @@ cutlass_test_unit_add_executable( epilogue_volta_tensor_op.cu epilogue_wmma_tensor_op_sm70.cu epilogue_planar_complex.cu + epilogue_with_reduction_tensor_op.cu ) diff --git a/test/unit/epilogue/threadblock/epilogue_simt.cu b/test/unit/epilogue/threadblock/epilogue_simt.cu index 72b86cfa..9821b24a 100644 --- a/test/unit/epilogue/threadblock/epilogue_simt.cu +++ b/test/unit/epilogue/threadblock/epilogue_simt.cu @@ -32,6 +32,7 @@ #include "cutlass/aligned_buffer.h" #include "cutlass/complex.h" +#include "cutlass/quaternion.h" #include "cutlass/gemm/warp/mma_simt.h" #include "cutlass/gemm/warp/mma_simt_policy.h" @@ -1088,4 +1089,80 @@ TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_128x128_32x64x8) { EXPECT_TRUE(passed); } -/////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Quaternion-valued single-precision +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Epilogue_threadblock_epilogue, simt_quaternion_f32_32x64_32x64x8) { + + // + // Define the warp-level matrix multiply + // + + using Element = cutlass::Quaternion; + using ElementOutput = Element; + using ElementAccumulator = Element; + using ElementCompute = Element; + int const kElementsPerAccess = 1; + + using Shape = 
cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; + + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using ElementOutput = Element; + using ElementAccumulator = Element; + using ElementCompute = Element; + + using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< + WarpShape, + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<4, 8>, + cutlass::layout::RowMajorInterleaved<2>, + cutlass::gemm::GemmShape<2, 2, 1> + > + >; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + kElementsPerAccess, + ElementAccumulator, + ElementCompute + >; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + Shape, + WarpMmaSimt, + OutputOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} diff --git a/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu b/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu new file mode 100644 index 00000000..eb1181db --- /dev/null +++ b/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu @@ -0,0 +1,875 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + + \brief Unit tests for thread-level GEMM +*/ + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/epilogue/thread/linear_combination_drelu.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" +#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "epilogue_with_reduction_testbed.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Disable selected tests on CUDA 11.1 +// +// +#define ENABLE_BLOCKED_TESTS (!(__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 1)) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_64x64_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + 
LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_64x64_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess 
+ >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x128_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + 
kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x128_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = 
testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x64_64x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if ENABLE_BLOCKED_TESTS + 
+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x64_64x32x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_64x128_32x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = 
float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_64x128_32x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<64, 128, 8>; + using 
WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x256_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 256, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x256_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<128, 256, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; 
+ using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_256x128_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = float; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<256, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + 
LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_256x128_64x64x8) { + + // + // Define the warp-level matrix multiply + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + int const kPartitionsK = 1; + + using Shape = cutlass::gemm::GemmShape<256, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = ElementAccumulator; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutC = cutlass::layout::RowMajor; + + using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + LayoutC>::Type; + + // + // Output operator + // + + using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + ElementAccumulator, + ElementAccumulator, + ElementOutput, + ElementOutput, + 
kElementsPerAccess + >; + + using ReductionOp = cutlass::plus; + + // + // Define the epilogue + // + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< + Shape, + WarpMmaTensorOp, + kPartitionsK, + ElementOutput, + OutputOp, + ReductionOp, + kElementsPerAccess + >::Epilogue; + + // + // Instantiate epilogue + // + + EpilogueWithReductionTestbed testbed; + + bool passed = testbed.run_all(); + + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h new file mode 100644 index 00000000..3a8ad743 --- /dev/null +++ b/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for epilogues +*/ +#pragma once + +#include <fstream> + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template <typename Epilogue> +__global__ void epilogue_with_reduction_threadblock( + typename Epilogue::ElementVector *ptr_Reduction, + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::TensorTileIterator::Params params_Tensor, + typename Epilogue::TensorTileIterator::Element *ptr_Tensor, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord 
problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::TensorTileIterator iterator_T( + params_Tensor, + ptr_Tensor, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. 
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + //typename Epilogue::WarpMmaOperator::ElementC x(i); + accumulators[i] = x; + } + + /* + #pragma unroll 1 + for (int tid = 0; tid < 32; ++tid) { + if (tid == thread_idx) { + printf("\nT%d: ", thread_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < accumulators.size(); ++i) { + printf("%d ", int(accumulators[i])); + } + } + } + + if (thread_idx == 0) { + printf("\n\n"); + } + */ + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueWithReductionTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementTensor = typename Epilogue::TensorTileIterator::Element; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + cutlass::HostTensor additional_tensor; + cutlass::HostTensor reduction_tensor; + + +public: + + // + // Methods + // + + EpilogueWithReductionTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + 
output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + reduction_tensor({1, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + + cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); + } + + bool run_all() { + + /* + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + return true; + */ + + double alpha = 1; + double beta = 0; + + return run( + {quantized_size.row(), quantized_size.column()}, + {cutlass::from_real(alpha), cutlass::from_real(beta)}); + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + ElementAccumulator default_reduction = ElementAccumulator(); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + 
source_tensor.sync_device(); + additional_tensor.sync_device(); + reduction_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( + reduction_tensor.device_data(), + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + params_T, + additional_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + reduction_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + // + // The output has two parts: + // - GEMM tensor epilogue in canonical layout + // - partial reduction in canonical row-major layout + // + + // Verify the GEMM tensor output + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord))); + } + else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + 
EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // Verify the partial reduction + for (int c = 0; c < quantized_size.column(); ++c) { + + ElementAccumulator reduction_acc = ElementAccumulator(); + + for (int r = 0; r < quantized_size.row(); ++r) { + reduction_acc += accumulator_tensor.at({r, c}); + } + + ElementAccumulator expected = default_reduction; + ElementAccumulator got = reduction_tensor.at({0, c}); + + if (c < problem_size.column()) { + expected = reduction_acc; + } + else { + expected = default_reduction; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - reduction element (" << c << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/threadblock/output_tile_threadmap.cu b/test/unit/epilogue/threadblock/output_tile_threadmap.cu index 19824e8a..cb59d588 100644 --- a/test/unit/epilogue/threadblock/output_tile_threadmap.cu +++ b/test/unit/epilogue/threadblock/output_tile_threadmap.cu @@ -63,7 +63,7 @@ struct OutputTileThreadMapExpr { }; int const kWarpSize = 32; - int const kMemoryAccessSize = 128; // size in bytes of the preferred memory access size + int const kMemoryAccessSize = 256; // size in bytes of the preferred memory 
access size // // Data members diff --git a/test/unit/epilogue/threadblock/testbed.h b/test/unit/epilogue/threadblock/testbed.h index ba5241af..e58adbbb 100644 --- a/test/unit/epilogue/threadblock/testbed.h +++ b/test/unit/epilogue/threadblock/testbed.h @@ -28,13 +28,14 @@ #pragma once #include +#include #include "../../common/cutlass_unit_test.h" #include "cutlass/aligned_buffer.h" #include "cutlass/half.h" #include "cutlass/complex.h" - +#include "cutlass/quaternion.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/util/host_tensor.h" @@ -307,10 +308,18 @@ public: ElementOutput expected; if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { - expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + - output_params.beta * ElementCompute(source_tensor.at(coord))); - } - else { + ElementCompute intermediate = + output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord)); + + if (std::numeric_limits::is_integer + && !std::numeric_limits::is_integer) { + std::fesetround(FE_TONEAREST); + expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); + } else { + expected = ElementOutput(intermediate); + } + } else { expected = default_output; } @@ -322,7 +331,11 @@ public: << "-------\n" << "Error - output element (" << coord << ") - expected: " << OutputIO(expected) - << ", got: " << OutputIO(got) << std::endl; + << ", got: " << OutputIO(got) + << ", accum: " << (accumulator_tensor.at(coord)) + << ", source: " << OutputIO(source_tensor.at(coord)) + << ", alpha: " << (output_params.alpha) + << ", beta: " << (output_params.beta) << "\n"; ++errors; } diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 87e49598..5590bbe9 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -34,6 +34,7 @@ add_custom_target( 
cutlass_test_unit_gemm_device_wmma cutlass_test_unit_gemm_device_tensorop_planar_complex cutlass_test_unit_gemm_device_sparse_tensorop_sm80 + cutlass_test_unit_gemv_device ) add_custom_target( @@ -50,6 +51,7 @@ add_custom_target( test_unit_gemm_device_wmma test_unit_gemm_device_tensorop_planar_complex test_unit_gemm_device_sparse_tensorop_sm80 + test_unit_gemv_device ) cutlass_test_unit_add_executable( @@ -66,6 +68,11 @@ cutlass_test_unit_add_executable( simt_cgemm_tn_sm50.cu simt_cgemm_tt_sm50.cu + simt_qgemm_nn_sm50.cu + simt_qgemm_nt_sm50.cu + simt_qgemm_tn_sm50.cu + simt_qgemm_tt_sm50.cu + simt_dgemm_nn_sm50.cu simt_dgemm_nt_sm50.cu simt_dgemm_tn_sm50.cu @@ -203,6 +210,7 @@ cutlass_test_unit_add_executable( gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu + ) cutlass_test_unit_add_executable( @@ -332,3 +340,36 @@ cutlass_test_unit_add_executable( gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu ) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemv_device + + BATCH_SOURCES ON + BATCH_SIZE 4 + + gemv.cu +) + +if (NOT CUDA_COMPILER MATCHES "[Cc]lang") + +add_dependencies( + cutlass_test_unit_gemm_device + cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + ) + +add_dependencies( + test_unit_gemm_device + test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + ) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + + gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu + gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu + + gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu +) + +endif() + diff --git a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu index 120cae05..c3213fd3 100644 --- a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu +++ b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu @@ -41,7 +41,6 @@ #include "testbed.h" 
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16) { @@ -209,4 +208,45 @@ TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Gemm_f64an_f64at_f64at_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2RowMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + double, + LayoutA, + double, + LayoutB, + ElementOutput, + LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu index 8f742573..de5182d3 100644 --- a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu +++ b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu 
@@ -209,4 +209,45 @@ TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16) { ///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Gemm_f64at_f64an_f64at_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using LayoutA = cutlass::layout::AffineRank2RowMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + double, + LayoutA, + double, + LayoutB, + ElementOutput, + LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu index f974cf16..cd8f0c0a 100644 --- a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu +++ b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu @@ -126,6 +126,222 @@ TEST(SM70_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_884, 64x64x32_32 } 
+//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_s884_nn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_s884_nn : gemm_planar_complex_s884_nn_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f32n_tensor_op_f32_884, 128x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_f16_s884_f16_nn_128x64_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 8, + float, + float + >, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_f16_s884_f16_nn_128x64_32x2 : gemm_planar_complex_f16_s884_f16_nn_128x64_32x2_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f16n_tensor_op_f32_884, 128x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_f16_s884_f16_nn_64x128_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 8, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_f16_s884_f16_nn_64x128_32x2 : gemm_planar_complex_f16_s884_f16_nn_64x128_32x2_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f16n_tensor_op_f32_884, 64x128x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_f16_s884_f16_tt_128x64_32x2_base = typename 
cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 8, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_f16_s884_f16_tt_128x64_32x2 : gemm_planar_complex_f16_s884_f16_tt_128x64_32x2_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16t_f16t_f16n_tensor_op_f32_884, 128x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_f16_s884_f16_tt_64x128_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + cutlass::half_t, + 8, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct 
gemm_planar_complex_f16_s884_f16_tt_64x128_32x2 : gemm_planar_complex_f16_s884_f16_tt_64x128_32x2_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmPlanarComplex_f16t_f16t_f16n_tensor_op_f32_884, 64x128x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + //////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu index f66ba86d..3ddeea33 100644 --- a/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu +++ b/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu @@ -81,6 +81,48 @@ TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_16816, 64x64x32_ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); } +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_f16_s16816_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_f16_s16816_tn : gemm_planar_complex_f16_s16816_tn_base { + +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} //////////////////////////////////////////////////////////////////////////////// @@ -127,6 +169,49 @@ TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_16816, 64x64x32_ //////////////////////////////////////////////////////////////////////////////// +using gemm_planar_complex_f16_s16816_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_f16_s16816_hc : gemm_planar_complex_f16_s16816_hc_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + +//////////////////////////////////////////////////////////////////////////////// + using gemm_planar_complex_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, @@ -168,6 +253,50 @@ 
TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_16816, 64x64x32_ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); } + +//////////////////////////////////////////////////////////////////////////////// + +using gemm_planar_complex_f16_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_f16_s16816_nt : gemm_planar_complex_f16_s16816_nt_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} + //////////////////////////////////////////////////////////////////////////////// using gemm_planar_complex_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< @@ -213,4 +342,46 @@ TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_16816, 64x64x32_ //////////////////////////////////////////////////////////////////////////////// +using gemm_planar_complex_cf16_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + 
cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + 8, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + float, + 4, + float, + float + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd +>::GemmKernel; + +struct gemm_planar_complex_cf16_s16816_ch : gemm_planar_complex_cf16_s16816_ch_base { + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); +} +//////////////////////////////////////////////////////////////////////////////// + #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu b/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu new file mode 100644 index 00000000..4a6f06cf --- /dev/null +++ b/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu @@ -0,0 +1,458 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_gemm_with_broadcast.h" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes: +/// +/// Z = GEMM+Bias+ReLu +/// T = Relu conditional +/// +template +struct GemmWithBiasReluReferenceOp { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + GemmWithBiasReluReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { + + ElementCompute kThreshold = ElementCompute(); + + ElementCompute z_full = binary_op(gemm, bias); + + bool conditional = (z_full >= kThreshold); + + if (!conditional) { + z_full = kThreshold; + } + + Z = ElementZ(z_full); + T = ElementT(conditional); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + 
+TEST(SM75_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + cutlass::half_t, + float, + float, + cutlass::half_t, + cutlass::half_t, + 8, + cutlass::epilogue::thread::GELU_taylor + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + cutlass::half_t, + float, + float, + cutlass::half_t, + cutlass::half_t, + 8, + cutlass::epilogue::thread::GELU_taylor + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + 
cutlass::gemm::GemmShape<8, 8, 4>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast(); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< + cutlass::half_t, + float, + float, + cutlass::half_t, + 8, + true + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast >(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< + cutlass::half_t, + float, + float, + cutlass::half_t, + 8, + true + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, 
cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast >(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defiend(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x5_64x64x32_16x8x16) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + cutlass::half_t, + float, + float, + cutlass::half_t, + cutlass::half_t, + 8, + cutlass::epilogue::thread::GELU_taylor + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast(); +} + +TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x5_64x64x32_16x8x16) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< + cutlass::half_t, + float, + float, + cutlass::half_t, + 8, + true + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast>(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x4_64x64x32_16x8x16) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + cutlass::half_t, + float, + float, + cutlass::half_t, + cutlass::half_t, + 8, + cutlass::epilogue::thread::GELU_taylor + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + 
cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast(); +} + +TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x4_64x64x32_16x8x16) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< + cutlass::half_t, + float, + float, + cutlass::half_t, + 8, + true + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast>(); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + cutlass::half_t, + float, + float, + cutlass::half_t, + cutlass::half_t, + 8, + cutlass::epilogue::thread::GELU_taylor + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, 
cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast(); +} + +TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< + cutlass::half_t, + float, + float, + cutlass::half_t, + 8, + true + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + test::gemm::device::TestAllGemmWithBroadcast >(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu b/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu new file mode 100644 index 00000000..b2be97af --- /dev/null +++ b/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/kernel/default_gemm_with_reduction.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/thread/linear_combination_drelu.h" +#include "cutlass/epilogue/thread/linear_combination_dgelu.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_gemm_with_reduction.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct dReluLambda { + float operator()(float d_y, float t) { + if (t <= 0) { + d_y = 0; + } + return d_y; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + float, + float, + cutlass::half_t, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + 
cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ReferenceOp = test::gemm::device::GemmWithReductionReference< + Gemm, + dReluLambda + >; + + test::gemm::device::TestGemmWithReduction( + {520, 264, 96}, + cutlass::gemm::GemmUniversalMode::kGemm, + 2, + float(1.25), + float(2.25) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 256x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + float, + float, + cutlass::half_t, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ReferenceOp = test::gemm::device::GemmWithReductionReference< + Gemm, + dReluLambda + >; + + test::gemm::device::TestGemmWithReduction( + {520, 264, 96}, + cutlass::gemm::GemmUniversalMode::kGemm, + 1, + float(1.25), + float(2.25) + ); +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + float, + float, + cutlass::half_t, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ReferenceOp = test::gemm::device::GemmWithReductionReference< + Gemm, + dReluLambda + >; + + test::gemm::device::TestGemmWithReduction( + {520, 264, 96}, + cutlass::gemm::GemmUniversalMode::kGemm, + 2, + float(1.25), + float(2.25) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 256x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< + float, + float, + cutlass::half_t, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + 
cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ReferenceOp = test::gemm::device::GemmWithReductionReference< + Gemm, + dReluLambda + >; + + test::gemm::device::TestGemmWithReduction( + {520, 264, 96}, + cutlass::gemm::GemmUniversalMode::kGemm, + 1, + float(1.25), + float(2.25) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +template +struct Gemm_dReLU_packed_bits_reference_op { + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; + using ElementC = typename Gemm::ElementC; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + + // + // Methods + // + + Gemm_dReLU_packed_bits_reference_op() { } + + ElementCompute operator()( + ElementAccumulator d_y, + ElementT t) const { + + ElementCompute result = ElementCompute(d_y); + + bool cond = bool(t); + if (!cond) { + result = ElementCompute(); + } + + return result; + } +}; + +} // namespace device +} // namespace gemm +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_Device_GemmWithReduction_dReLU_conditional_bits_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDReluConditionalBits< + float, + float, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, 
cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ReferenceOp = test::gemm::device::Gemm_dReLU_packed_bits_reference_op; + + test::gemm::device::TestGemmWithReduction( + {520, 264, 96}, + cutlass::gemm::GemmUniversalMode::kGemm, + 2, + float(1.25), + float(2.25) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM70_Device_GemmWithReduction_dReLU_conditional_bits_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDReluConditionalBits< + float, + float, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using 
ReferenceOp = test::gemm::device::Gemm_dReLU_packed_bits_reference_op; + + test::gemm::device::TestGemmWithReduction( + {520, 264, 96}, + cutlass::gemm::GemmUniversalMode::kGemm, + 2, + float(1.25), + float(2.25) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu b/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu new file mode 100644 index 00000000..9aee2b6b --- /dev/null +++ b/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/kernel/default_gemm_with_reduction.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/thread/linear_combination_drelu.h" +#include "cutlass/epilogue/thread/linear_combination_dgelu.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_gemm_with_reduction.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct dReluLambda { + float operator()(float d_y, float t) { + if (t <= 0) { + d_y = 0; + } + return d_y; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmWithReduction_dReLU_bGrad_f16t_f16n_f16n_tensor_op_f32, 128x128x32_64x64x32) { + + using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombinationDRelu< + float, + float, + cutlass::half_t, + cutlass::half_t, + 8 + >; + + using GemmKernel = + typename cutlass::gemm::kernel::DefaultGemmWithReduction< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand + cutlass::half_t, cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + EpilogueOutputOp, + cutlass::plus, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ReferenceOp = test::gemm::device::GemmWithReductionReference< + Gemm, + dReluLambda + >; + + test::gemm::device::TestGemmWithReduction( + {8, 8, 136}, + cutlass::gemm::GemmUniversalMode::kGemm + ); +} + +#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemv.cu b/test/unit/gemm/device/gemv.cu new file mode 100644 index 00000000..8cda6537 --- /dev/null +++ b/test/unit/gemm/device/gemv.cu @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Tests for device-wide GEMV interface +*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/kernel/gemv.h" +#include "cutlass/gemm/device/gemv.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { + +template +class TestbedGemv { +public: + + using ElementA = typename Gemv::ElementA; + using LayoutA = typename Gemv::LayoutA; + using ElementB = typename Gemv::ElementB; + using ElementC = typename Gemv::ElementC; + + using ElementAccumulator = typename Gemv::ElementAccumulator; + using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute; + + using LayoutV = cutlass::layout::RowMajor; + +private: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + +public: + + // + // Methods + // + + TestbedGemv( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor 
view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize( + cutlass::MatrixCoord problem_size + ) { + + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size); + tensor_B.resize({problem_size.column(), 1}); + tensor_C.resize({problem_size.row(), 1}); + tensor_D.resize({problem_size.row(), 1}); + reference_D.resize({problem_size.row(), 1}, false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Gemv::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemv::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemv::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::MatrixCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::MatrixCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemv::ElementA, typename Gemv::LayoutA, + typename Gemv::ElementB, LayoutV, + typename Gemv::ElementC, LayoutV, + ElementCompute, ElementAccumulator + >( + {problem_size.row(), 1, 
problem_size.column()}, + alpha, + tensor_A.host_ref(), + Gemv::kTransformA, + tensor_B.host_ref(), + Gemv::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Runs one problem size + bool run( + cutlass::MatrixCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemv::Arguments arguments{ + problem_size, + {alpha, beta}, + tensor_A.device_ref(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemv gemm_op; + + size_t workspace_size = Gemv::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemv() { + + using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute; + + int M[] = { + 8, 48, 192, 520 + }; + + int K[] = { + 8, 192, 528 + }; + + double Alpha[] = { + 1, 1.25 + }; + + double Beta[] = { + 0, 1, 1.25 + }; + + for (int m : M) { + for (int k : K) { + for (double alpha : Alpha) { + for (double beta : Beta) { + + TestbedGemv testbed; + + if (!testbed.run({m, k}, ElementCompute(alpha), ElementCompute(beta))) { + return false; + } + } + } + } + } + + return true; +} + +} // namespace gemm +} // namespace test + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, Simple) { + + using ElementOutput = float; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator>; + + using Gemv = cutlass::gemm::device::Gemv< + cutlass::gemm::kernel::Gemv< + ElementOutput, // Element A + LayoutA, // Layout A + ElementOutput, // Element B + ElementOutput, // Element C + ElementAccumulator, // Element Accumulator + EpilogueOp // Output operator + > + >; + + + EXPECT_TRUE(test::gemm::TestAllGemv()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemv_f16n_f16_f32_simt_f32, Simple) { + + using ElementInput = cutlass::half_t; + using ElementOutput = float; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator>; + + using Gemv = cutlass::gemm::device::Gemv< + cutlass::gemm::kernel::Gemv< + ElementInput, // Element A + LayoutA, // Layout A + ElementInput, // Element B + ElementOutput, // Element C + ElementAccumulator, // Element Accumulator + EpilogueOp // Output operator + > + >; + + + EXPECT_TRUE(test::gemm::TestAllGemv()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) { + + using ElementInput = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator>; + + using Gemv = 
cutlass::gemm::device::Gemv< + cutlass::gemm::kernel::Gemv< + ElementInput, // Element A + LayoutA, // Layout A + ElementInput, // Element B + ElementOutput, // Element C + ElementAccumulator, // Element Accumulator + EpilogueOp // Output operator + > + >; + + + EXPECT_TRUE(test::gemm::TestAllGemv()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_cgemm_nn_sm50.cu b/test/unit/gemm/device/simt_cgemm_nn_sm50.cu index 680012bc..82db30bf 100644 --- a/test/unit/gemm/device/simt_cgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_nn_sm50.cu @@ -673,66 +673,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// 
Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1093,96 +1033,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - 
using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 64 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; - - static 
int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 2 // Threads / Warp: 4 x 8 @@ -1243,36 +1093,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_cgemm_tn_sm50.cu b/test/unit/gemm/device/simt_cgemm_tn_sm50.cu index a6072d28..df873bc7 100644 --- a/test/unit/gemm/device/simt_cgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_tn_sm50.cu @@ -673,66 +673,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tn, 
16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1093,96 +1033,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, 
cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 64 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 2 // Threads / Warp: 4 x 8 @@ -1243,36 +1093,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - 
//////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_cgemm_tt_sm50.cu b/test/unit/gemm/device/simt_cgemm_tt_sm50.cu index 8162905b..bc9a6545 100644 --- a/test/unit/gemm/device/simt_cgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_cgemm_tt_sm50.cu @@ -673,66 +673,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - 
using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1093,96 +1033,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - 
EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 64 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - 
precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 2 // Threads / Warp: 4 x 8 @@ -1243,36 +1093,6 @@ CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = cutlass::complex; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 @@ -1302,4 +1122,3 @@ CUTLASS_TEST_L2(SM50_device_cgemm_tt, 
128x64x8_32x16x1_4x4_8x4_4x4, { >; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) - diff --git a/test/unit/gemm/device/simt_dgemm_nn_sm50.cu b/test/unit/gemm/device/simt_dgemm_nn_sm50.cu index af5dbb7c..3bdae057 100644 --- a/test/unit/gemm/device/simt_dgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_nn_sm50.cu @@ -643,6 +643,45 @@ CUTLASS_TEST_L0(SM50_device_dgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_dgemm_affine2_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { + using precision = double; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + 
//////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 8 x 4 // Threads / Warp: 8 x 4 @@ -673,66 +712,6 @@ CUTLASS_TEST_L1(SM50_device_dgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - 
precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -973,96 +952,6 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / 
Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 64 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 2 // Threads / Warp: 4 x 8 @@ -1094,32 +983,3 @@ CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { } ) //////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - diff --git a/test/unit/gemm/device/simt_dgemm_nt_sm50.cu b/test/unit/gemm/device/simt_dgemm_nt_sm50.cu index d5cb5e75..6e89e964 100644 --- a/test/unit/gemm/device/simt_dgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_nt_sm50.cu @@ -643,6 +643,45 @@ CUTLASS_TEST_L0(SM50_device_dgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 4 +// Threads / Warp: 4 x 8 +// 
Warps / Block: 2 x 2 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_dgemm_affine2_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { + using precision = double; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2RowMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 8 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_dgemm_tn_sm50.cu b/test/unit/gemm/device/simt_dgemm_tn_sm50.cu index 84cb465b..49005732 100644 --- a/test/unit/gemm/device/simt_dgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_tn_sm50.cu @@ -643,6 +643,45 @@ CUTLASS_TEST_L0(SM50_device_dgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// 
Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_dgemm_affine2_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { + using precision = double; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2RowMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 8 x 4 // Threads / Warp: 8 x 4 @@ -673,66 +712,6 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static 
int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -973,96 +952,6 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 
128x32x8_32x16x1_4x4_8x4_4x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - 
precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 64 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 2 // Threads / Warp: 4 x 8 @@ -1094,32 +983,3 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { } ) //////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 
32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - diff --git a/test/unit/gemm/device/simt_dgemm_tt_sm50.cu b/test/unit/gemm/device/simt_dgemm_tt_sm50.cu index e9633f5c..f30de3fb 100644 --- a/test/unit/gemm/device/simt_dgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_dgemm_tt_sm50.cu @@ -163,6 +163,45 @@ CUTLASS_TEST_L0(SM50_device_dgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L0(SM50_device_dgemm_affine2_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { + using precision = double; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = 
cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 4 x 8 @@ -673,66 +712,6 @@ CUTLASS_TEST_L1(SM50_device_dgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - 
-//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -973,96 +952,6 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 64 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, 
kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 64 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; - using WarpShape = 
cutlass::gemm::GemmShape<16, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 2 // Threads / Warp: 4 x 8 @@ -1094,32 +983,3 @@ CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { } ) //////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_dgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = double; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // 
Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - diff --git a/test/unit/gemm/device/simt_igemm_nn_sm50.cu b/test/unit/gemm/device/simt_igemm_nn_sm50.cu index be25b520..2ca7cca4 100644 --- a/test/unit/gemm/device/simt_igemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_nn_sm50.cu @@ -943,36 +943,6 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = int; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1483,36 +1453,6 @@ CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 
-// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = int; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_igemm_tn_sm50.cu b/test/unit/gemm/device/simt_igemm_tn_sm50.cu index 2a871ecc..e66b5114 100644 --- a/test/unit/gemm/device/simt_igemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_tn_sm50.cu @@ -943,36 +943,6 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = int; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using 
EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1483,36 +1453,6 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = int; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - 
//////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1633,36 +1573,6 @@ CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = int; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_igemm_tt_sm50.cu b/test/unit/gemm/device/simt_igemm_tt_sm50.cu index f86e8e97..52b8e1b9 100644 --- a/test/unit/gemm/device/simt_igemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_igemm_tt_sm50.cu @@ -1633,36 +1633,6 @@ CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / 
Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = int; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_qgemm_nn_sm50.cu b/test/unit/gemm/device/simt_qgemm_nn_sm50.cu new file mode 100644 index 00000000..3bca3244 --- /dev/null +++ b/test/unit/gemm/device/simt_qgemm_nn_sm50.cu @@ -0,0 +1,855 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include <iostream> + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/numeric_types.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = 
cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// 
Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + 
EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 1 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = 
cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 1 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const 
kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 32 x 8 
+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 16 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, 
cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { + using precision = cutlass::Quaternion; + using 
ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + 
+//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + 
cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 2 +// Threadblock: 128 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 32 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + 
precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 32 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; 
+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// 
Threads / Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 128 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + diff --git a/test/unit/gemm/device/simt_qgemm_nt_sm50.cu b/test/unit/gemm/device/simt_qgemm_nt_sm50.cu new file mode 100644 index 00000000..fa036155 --- /dev/null +++ b/test/unit/gemm/device/simt_qgemm_nt_sm50.cu @@ -0,0 +1,855 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/numeric_types.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { + using precision = cutlass::Quaternion; + using ThreadblockShape = 
cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements 
/ Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 1 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, 
cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 1 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 
32x32x8_16x16x1_4x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 16 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, 
cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { + using precision = cutlass::Quaternion; + using 
ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + 
+//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + 
cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 2 +// Threadblock: 128 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 32 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + 
precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 32 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + 
using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / 
Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 128 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + diff --git a/test/unit/gemm/device/simt_qgemm_tn_sm50.cu b/test/unit/gemm/device/simt_qgemm_tn_sm50.cu new file mode 100644 index 00000000..ec0ec5ff --- /dev/null +++ b/test/unit/gemm/device/simt_qgemm_tn_sm50.cu @@ -0,0 +1,855 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include <iostream> + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/numeric_types.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = 
cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements 
/ Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 1 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, 
cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 1 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 
32x32x8_16x16x1_4x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 16 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, 
cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { + using precision = cutlass::Quaternion; + using 
ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + 
+//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + 
cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 2 +// Threadblock: 128 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 32 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + 
precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 32 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + 
using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / 
Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 128 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::ColumnMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + diff --git a/test/unit/gemm/device/simt_qgemm_tt_sm50.cu b/test/unit/gemm/device/simt_qgemm_tt_sm50.cu new file mode 100644 index 00000000..a143d9b8 --- /dev/null +++ b/test/unit/gemm/device/simt_qgemm_tt_sm50.cu @@ -0,0 +1,855 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*!
\file + \brief Tests for device-wide GEMM interface +*/ + +#include <iostream> + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/numeric_types.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 1 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = 
cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { + using precision = cutlass::Quaternion<float>; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / 
Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 8 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 1 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 1 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, 
cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 1 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 16 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { + using 
precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 2 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L0(SM50_device_qgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + 
+//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 16 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + 
cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, 
kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 2 x 4 +// Threadblock: 32 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = 
cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 32 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 
4 x 2 +// Threadblock: 64 x 32 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 2 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 
+ 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 2 +// Threadblock: 128 x 32 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 32 x 64 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + 
precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 2 x 2 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 32 x 16 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 2 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 64 x 8 +CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = 
cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 4 x 8 +// Warps / Block: 4 x 4 +// Threadblock: 64 x 128 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 4 x 4 +// Threads / Warp: 8 x 4 +// Warps / Block: 4 x 4 +// Threadblock: 128 x 64 x 8 +CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { + using precision = cutlass::Quaternion; + using ThreadblockShape 
= cutlass::gemm::GemmShape<128, 64, 8>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, cutlass::layout::RowMajor, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + diff --git a/test/unit/gemm/device/simt_sgemm_nn_sm50.cu b/test/unit/gemm/device/simt_sgemm_nn_sm50.cu index 64e524b4..a167ad67 100644 --- a/test/unit/gemm/device/simt_sgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_nn_sm50.cu @@ -943,36 +943,6 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = float; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, 
cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1213,6 +1183,45 @@ CUTLASS_TEST_L0(SM50_device_sgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 8 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 128 x 128 x 8 +CUTLASS_TEST_L0(SM50_device_sgemm_affine2_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { + using precision = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + 
EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 4 x 8 @@ -1483,36 +1492,6 @@ CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = float; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_sgemm_nt_sm50.cu b/test/unit/gemm/device/simt_sgemm_nt_sm50.cu index e520e298..750fe7c3 100644 --- a/test/unit/gemm/device/simt_sgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_nt_sm50.cu @@ -1213,6 +1213,45 @@ CUTLASS_TEST_L0(SM50_device_sgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 8 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 128 x 128 x 8 +CUTLASS_TEST_L0(SM50_device_sgemm_affine2_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { + using precision = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2RowMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 4 x 8 diff --git a/test/unit/gemm/device/simt_sgemm_nt_sm80.cu b/test/unit/gemm/device/simt_sgemm_nt_sm80.cu index 3a1b5de6..c0691f36 100644 --- a/test/unit/gemm/device/simt_sgemm_nt_sm80.cu +++ b/test/unit/gemm/device/simt_sgemm_nt_sm80.cu @@ -130,6 +130,43 @@ TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_32x64x1) { 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); } +TEST(SM80_Device_Gemm_f32an_f32at_f32at_simt_f32, 128x128x8_32x64x1) { + + using Element = float; + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2RowMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C )); + +} + TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x128x8_32x64x1) { using Element = float; @@ -248,7 +285,6 @@ TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x256x8_64x64x1) { } ///////////////////////////////////////////////////////////////////////////////////////////////// - #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_sgemm_tn_sm50.cu b/test/unit/gemm/device/simt_sgemm_tn_sm50.cu index aa3a0d6e..a7d4a698 100644 --- a/test/unit/gemm/device/simt_sgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_tn_sm50.cu @@ -943,36 +943,6 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / 
Warp: 4 x 8 -// Warps / Block: 2 x 4 -// Threadblock: 16 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { - using precision = float; - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1213,6 +1183,45 @@ CUTLASS_TEST_L0(SM50_device_sgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 8 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 128 x 128 x 8 +CUTLASS_TEST_L0(SM50_device_sgemm_affine2_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { + using precision = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2RowMajor; + 
using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 4 x 8 @@ -1483,36 +1492,6 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 2 x 4 -// Threads / Warp: 4 x 8 -// Warps / Block: 4 x 4 -// Threadblock: 32 x 128 x 16 -CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { - using precision = float; - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 8 x 4 @@ -1633,36 +1612,6 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = float; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::ColumnMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_sgemm_tn_sm80.cu b/test/unit/gemm/device/simt_sgemm_tn_sm80.cu index 9ed5f129..b53aaf5c 100644 --- a/test/unit/gemm/device/simt_sgemm_tn_sm80.cu +++ b/test/unit/gemm/device/simt_sgemm_tn_sm80.cu @@ -44,7 +44,6 @@ #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) 
//////////////////////////////////////////////////////////////////////////////// - TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 32x64x8_32x64x1) { using Element = float; @@ -132,6 +131,42 @@ TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_32x64x1) { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } +TEST(SM80_Device_Gemm_f32at_f32an_f32t_simt_f32, 128x128x8_32x64x1) { + + using Element = float; + using LayoutA = cutlass::layout::AffineRank2RowMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + Element, + LayoutA, + Element, + LayoutB, + Element, + LayoutC, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {1}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm( stride_factor_A, stride_factor_B, stride_factor_C )); +} + TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x128x8_32x64x1) { using Element = float; diff --git a/test/unit/gemm/device/simt_sgemm_tt_sm50.cu b/test/unit/gemm/device/simt_sgemm_tt_sm50.cu index c148c956..67b7029b 100644 --- a/test/unit/gemm/device/simt_sgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_sgemm_tt_sm50.cu @@ -1213,6 +1213,45 @@ CUTLASS_TEST_L0(SM50_device_sgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +//////////////////////////////////////////////////////////////////////////////// +// Elements / Thread: 8 x 8 +// Threads / Warp: 8 x 4 +// Warps / Block: 2 x 4 +// Threadblock: 128 x 128 x 8 
+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { + using precision = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; + + static int const kEpilogueElementsPerAccess = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + precision, kEpilogueElementsPerAccess, precision, precision>; + + using LayoutA = cutlass::layout::AffineRank2ColumnMajor; + using LayoutB = cutlass::layout::AffineRank2ColumnMajor; + using LayoutC = cutlass::layout::AffineRankN<2>; + + using Gemm = cutlass::gemm::device::Gemm< + precision, LayoutA, + precision, LayoutB, + precision, LayoutC, + precision, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 // Stages + >; + + typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; + typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; + typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; + + EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); +} ) + //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 2 x 2 // Threads / Warp: 4 x 8 @@ -1633,36 +1672,6 @@ CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) -//////////////////////////////////////////////////////////////////////////////// -// Elements / Thread: 4 x 2 -// Threads / Warp: 8 x 4 -// Warps / Block: 4 x 4 -// Threadblock: 128 x 32 x 16 -CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { - using precision = float; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; - using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; - - static int const 
kEpilogueElementsPerAccess = 1; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< - precision, kEpilogueElementsPerAccess, precision, precision>; - - using Gemm = cutlass::gemm::device::Gemm< - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, cutlass::layout::RowMajor, - precision, - cutlass::arch::OpClassSimt, - cutlass::arch::Sm50, - ThreadblockShape, WarpShape, InstructionShape, - EpilogueOutputOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 2 // Stages - >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); -} ) - //////////////////////////////////////////////////////////////////////////////// // Elements / Thread: 4 x 4 // Threads / Warp: 8 x 4 diff --git a/test/unit/gemm/device/simt_sm50.py b/test/unit/gemm/device/simt_sm50.py index 525fa2a8..9c70e910 100644 --- a/test/unit/gemm/device/simt_sm50.py +++ b/test/unit/gemm/device/simt_sm50.py @@ -45,14 +45,15 @@ warpShapeMin = 8*8 threadblockEdgeMax = 256 -# char, type bits/elem, max tile, L0 threadblock tiles +# char, type bits/elem, max tile, L0 threadblock tiles precisions = [ - ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], - ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], - ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], - ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], - ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], - ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], + ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["q", "cutlass::Quaternion", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], + ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], + ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], + ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] 
] ], + ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], ] # L1 will have a single kernel for every unique shape # L2 will have everything else @@ -313,7 +314,7 @@ for precision in precisions: " cutlass::arch::Sm50,\n" " ThreadblockShape, WarpShape, InstructionShape,\n" " EpilogueOutputOp,\n" - " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,\n" + " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\n" " 2 // Stages\n" " >;\n" % ( "Column" if columnMajorA else "Row", diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index 24ec13e4..acad2766 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -45,6 +45,8 @@ #include "testbed_utils.h" +#include "cutlass/layout/matrix.h" + namespace test { namespace gemm { namespace device { @@ -58,6 +60,9 @@ struct Testbed { using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; /// Initialization + typename Gemm::LayoutA::Stride stride_factor_A; + typename Gemm::LayoutB::Stride stride_factor_B; + typename Gemm::LayoutC::Stride stride_factor_C; cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_C; @@ -79,6 +84,23 @@ struct Testbed { cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): + stride_factor_A(typename Gemm::LayoutA::Stride()), + stride_factor_B(typename Gemm::LayoutB::Stride()), + stride_factor_C(typename Gemm::LayoutC::Stride()), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + Testbed( + typename Gemm::LayoutA::Stride stride_factor_A_, + typename Gemm::LayoutB::Stride stride_factor_B_, + typename Gemm::LayoutC::Stride stride_factor_C_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + 
): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_C(stride_factor_C_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } /// Helper to initialize a tensor view @@ -139,11 +161,11 @@ struct Testbed { // Allocate the GEMM workspace // - tensor_A.resize(problem_size.mk()); - tensor_B.resize(problem_size.kn()); - tensor_C.resize(problem_size.mn()); - tensor_D.resize(problem_size.mn()); - reference_D.resize(problem_size.mn(), false); + tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); + tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); + tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); @@ -153,7 +175,7 @@ struct Testbed { // in the upper left corner of each operand. 
tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); @@ -226,7 +248,7 @@ struct Testbed { // // Verify // - + cutlass::reference::host::Gemm< typename Gemm::ElementA, typename Gemm::LayoutA, typename Gemm::ElementB, typename Gemm::LayoutB, @@ -347,7 +369,10 @@ struct Testbed { ///////////////////////////////////////////////////////////////////////////////////////////////// template -bool TestAllGemm() { +bool TestAllGemm( + const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { bool passed = true; int const kMinimumOperandElementSize = @@ -393,7 +418,7 @@ bool TestAllGemm() { 2.0 }; - Testbed testbed; + Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; @@ -414,7 +439,6 @@ bool TestAllGemm() { for (auto beta : problem_beta) { cutlass::gemm::GemmCoord problem_size(m, n, k); - passed = testbed.run( problem_size, split_k, diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h new file mode 100644 index 00000000..d933d41d --- /dev/null +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -0,0 +1,651 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithBroadcastReferenceOp { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + GemmWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { + + ElementCompute z_full = binary_op(gemm, bias); + Z = ElementZ(z_full); + + ElementCompute t_full = elementwise_op(z_full); + T = ElementT(t_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = GEMM(AB, C) +// +// Z[i, j] = ReductionOp(Y[i, j], Broadcast[i]) +// +// T[i, j] = Elementwise(Z[i, j]) +// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +struct TestbedGemmWithBroadcast { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using 
ElementCOmpute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; // Input A + cutlass::HostTensor tensor_B; // Input B + cutlass::HostTensor tensor_C; // Input C + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + + cutlass::HostTensor tensor_Z; + cutlass::HostTensor tensor_T; + + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor tensor_Y_ref; + cutlass::HostTensor tensor_Z_ref; + cutlass::HostTensor tensor_T_ref; + + + // + // Methods + // + + TestbedGemmWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + 
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_Z.resize(problem_size.mn()); + tensor_T.resize(problem_size.mn()); + tensor_Broadcast.resize({ + problem_size.m(), + 1 + }); + + tensor_C_ref.resize(problem_size.mn()); + tensor_Y_ref.resize(problem_size.mn()); + tensor_Z_ref.resize(problem_size.mn()); + tensor_T_ref.resize(problem_size.mn()); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + + tensor_Z.sync_device(); + tensor_T.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Z.sync_host(); + tensor_T.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); + + bool passed = true; + float norm_diff = 0; + + if (OutputOp::kStoreZ) { + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); + passed = (norm_diff <= 0.1f); + EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; + } + + if (OutputOp::kStoreT) { + + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); + passed = (passed && (norm_diff <= 0.1f)); + + EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; + } + + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" 
+ << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); + + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nZ =\n" << tensor_Z.host_view() + << "\nT =\n" << tensor_T.host_view() + << "\n\n" + << "\nY_ref =\n" << tensor_Y_ref.host_view() + << "\nZ_ref =\n" << tensor_Z_ref.host_view() + << "\nT_ref =\n" << tensor_T_ref.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + tensor_Y_ref.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + + // compute tensor Z and tensor T + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + + ElementZ z; + ElementT t; + + reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); + + tensor_Z_ref.at({m, n}) = z; + tensor_T_ref.at({m, n}) = t; + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA 
device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_Z.device_data(), + tensor_Broadcast.device_data(), + tensor_T.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = true; + + passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. 
+ + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), + profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), + profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + 
tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestGemmWithBroadcast( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestAllGemmWithBroadcast() { + + int M_problems[] = {8, 136, 264, 520}; + int N_problems[] = {8, 136, 264, 520}; + int K_problems[] = {8, 136, 264, 520}; + double alpha_problems[] = {1.25, 2.25}; + double beta_problems[] = {0, 1, 2.0}; + + bool passed = true; + + for (int M : M_problems) { + for (int N : 
N_problems) { + for (int K : K_problems) { + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + 1, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; + + if (!passed) { + + return passed; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/testbed_gemm_with_reduction.h b/test/unit/gemm/device/testbed_gemm_with_reduction.h new file mode 100644 index 00000000..473cd664 --- /dev/null +++ b/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -0,0 +1,491 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithReductionReference { + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; + using ElementC = typename Gemm::ElementC; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + // + // Data members 
+ // + + BinaryOp binary_op; + + // + // Methods + // + + GemmWithReductionReference() { } + + ElementCompute operator()( + ElementAccumulator d_y, + ElementT t) { + + return binary_op(ElementCompute(d_y), ElementCompute(t)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp +> +struct TestbedGemmWithReduction { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor reference_d_Y; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Reduction; + + // + // Methods + // + + TestbedGemmWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } 
else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + for (int m = 0; m < view.extent().row(); ++m) { + for (int n = 0; n < view.extent().column(); ++n) { + //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); + view.at({m, n}) = (n == 0 ? Element(m) : Element()); + + } + } + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + + tensor_Reduction.resize({ + problem_size.m(), + (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN + }); + + tensor_Tensor.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + reference_d_Y.resize(problem_size.mn(), false); + tensor_C_ref.resize(problem_size.mn(), false); + reference_Reduction.resize({problem_size.m(), 1}, false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Reduction.sync_device(); + tensor_Tensor.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Reduction.sync_host(); + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); + + bool passed = true; + for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { + + ElementAccumulator reduced_value = ElementAccumulator(); + for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { + reduced_value += tensor_Reduction.at({m, j}); + } + + if (reduced_value != reference_Reduction.at({m, 0})) { + std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; + passed = false; + break; + } + } + EXPECT_TRUE(passed) << "Reduction is incorect."; + + if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { + EXPECT_TRUE(false) << " mismatched 
reference"; + passed = false; + } + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors_sm70.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nT = \n" << tensor_Tensor.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view() + << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" + << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + reference_d_Y.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute backwards + for (int m = 0; m < problem_size.m(); ++m) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.n(); ++n) { + ElementAccumulator d_full = reference_op(reference_d_Y.at({m, 
n}), tensor_Tensor.at({m, n})); + reduced_value += d_full; + reference_D.at({m, n}) = ElementC(d_full); + } + reference_Reduction.at({m, 0}) = reduced_value; + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmWithReduction( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count = 1, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithReduction testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + 
cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/testbed_planar_complex.h b/test/unit/gemm/device/testbed_planar_complex.h index 3bc99775..29d1254d 100644 --- a/test/unit/gemm/device/testbed_planar_complex.h +++ b/test/unit/gemm/device/testbed_planar_complex.h @@ -103,8 +103,8 @@ public: cutlass::reference::host::TensorFillRandomUniform( tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); - cutlass::reference::host::TensorFill(tensor_D.host_view()); - cutlass::reference::host::TensorFill(tensor_D_ref.host_view()); + cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); + cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); tensor_A.sync_device(); tensor_B.sync_device(); @@ -162,10 +162,10 @@ public: ElementC *ptr_C = tensor_C.device_data(); ElementC *ptr_D = tensor_D.device_data(); - int lda = tensor_A.layout().stride(0); - int ldb = tensor_B.layout().stride(0); - int ldc = tensor_C.layout().stride(0); - int ldd = tensor_D.layout().stride(0); + typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); + typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); + typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); + typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); int64_t imag_stride_A = tensor_A.imaginary_stride(); int64_t imag_stride_B = tensor_B.imaginary_stride(); @@ -266,15 +266,15 @@ template bool TestAllGemmPlanarComplex() { int M[] = { - 16, 264, + 16, 64, 72, 144, 264, 520, }; int N[] = { - 16, 248, + 16, 64, 72, 144, 248, 264, 520 }; int K[] = { - 8, 96, + 8, 64, 72, 96, 264, 520 }; using ElementCompute = typename 
Gemm::EpilogueOutputOp::ElementCompute; diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h index e2611210..5eb75efe 100644 --- a/test/unit/gemm/device/testbed_sparse.h +++ b/test/unit/gemm/device/testbed_sparse.h @@ -477,4 +477,3 @@ bool TestAllSparseGemm() { } // namespace test ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/test/unit/gemm/threadblock/mma_multistage.cu b/test/unit/gemm/threadblock/mma_multistage.cu index 8e769041..cce15d65 100644 --- a/test/unit/gemm/threadblock/mma_multistage.cu +++ b/test/unit/gemm/threadblock/mma_multistage.cu @@ -22,6 +22,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + /*! \file \brief Unit tests for threadblock-level GEMM */ @@ -3824,4 +3825,5 @@ TEST(SM80_gemm_threadblock_crosswise_f64, } //////////////////////////////////////////////////////////////////////////////// + #endif diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h index a947af7f..d915c045 100644 --- a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -59,7 +59,8 @@ __global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_si typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC *ptr_C, int ldc, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, typename Mma::IteratorE::Params params_E, typename Mma::IteratorE::TensorRef ref_E) { // Shared storage needed by threadblock-scoped matrix multiply- diff --git a/test/unit/gemm/threadblock/mma_multistage_testbed.h b/test/unit/gemm/threadblock/mma_multistage_testbed.h index 
84dfdbdb..7da622df 100644 --- a/test/unit/gemm/threadblock/mma_multistage_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -57,7 +57,8 @@ __global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC *ptr_C, int ldc) { + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { // Shared storage needed by threadblock-scoped matrix multiply-accumulate // Dynamic shared memory base pointer diff --git a/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/test/unit/gemm/threadblock/mma_pipelined_testbed.h index ee71c51a..4f8ffd15 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_testbed.h +++ b/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -67,7 +67,8 @@ __global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::ElementC *ptr_C, int ldc) { + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { // Shared storage needed by threadblock-scoped matrix multiply-accumulate __shared__ typename Mma::SharedStorage shared_storage; diff --git a/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h index e1b537d5..e2fb8f7a 100644 --- a/test/unit/gemm/threadblock/mma_planar_complex_testbed.h +++ b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -67,7 +67,8 @@ __global__ void kernel_mma_planar_complex( typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::Element *ptr_B, int64_t imaginary_stride_B, - typename Mma::ElementC *ptr_C, int ldc, int64_t imaginary_stride_C) { + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { // Shared storage needed by 
threadblock-scoped matrix multiply-accumulate __shared__ typename Mma::SharedStorage shared_storage; diff --git a/test/unit/gemm/warp/gemm_sm50.cu b/test/unit/gemm/warp/gemm_sm50.cu index 88b84d87..e6a782ef 100644 --- a/test/unit/gemm/warp/gemm_sm50.cu +++ b/test/unit/gemm/warp/gemm_sm50.cu @@ -29,6 +29,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/complex.h" +#include "cutlass/quaternion.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma_simt.h" @@ -593,3 +594,55 @@ TEST(SM50_warp_gemm_complex_f64_col_row_row, 32x16x1_1x1x1) { test::gemm::warp::Testbed>().run(); } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_warp_gemm_quaternion_f32_col_row_col, 16x8x8_1x1x1) { + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<8, 4>, + cutlass::layout::ColumnMajorInterleaved<2>, + cutlass::gemm::GemmShape<1, 1, 1> + >; + + using quaternion_f32_t = cutlass::Quaternion; + + using Mma = cutlass::gemm::warp::MmaSimt< + cutlass::gemm::GemmShape<16, 8, 8>, + quaternion_f32_t, + cutlass::layout::ColumnMajor, + quaternion_f32_t, + cutlass::layout::RowMajor, + quaternion_f32_t, + cutlass::layout::ColumnMajor, + Policy + >; + + test::gemm::warp::Testbed>().run(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM50_warp_gemm_quaternion_f32_col_row_row, 16x8x8_1x1x1) { + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape<8, 4>, + cutlass::layout::ColumnMajorInterleaved<2>, + cutlass::gemm::GemmShape<1, 1, 1> + >; + + using quaternion_f32_t = cutlass::Quaternion; + + using Mma = cutlass::gemm::warp::MmaSimt< + cutlass::gemm::GemmShape<16, 8, 8>, + quaternion_f32_t, + cutlass::layout::ColumnMajor, + quaternion_f32_t, + cutlass::layout::RowMajor, + quaternion_f32_t, + cutlass::layout::RowMajor, + Policy + >; + + test::gemm::warp::Testbed>().run(); +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/warp/gemm_sm80.cu b/test/unit/gemm/warp/gemm_sm80.cu index 32abb541..eedea164 100644 --- a/test/unit/gemm/warp/gemm_sm80.cu +++ b/test/unit/gemm/warp/gemm_sm80.cu @@ -1856,3 +1856,4 @@ TEST(SM80_warp_gemm_tensor_op_canonical_tf32_col_row, 32x32x8_64x32x8_8x8x4) { #endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/warp/testbed.h b/test/unit/gemm/warp/testbed.h index cc5b55b2..5a95ee49 100644 --- a/test/unit/gemm/warp/testbed.h +++ b/test/unit/gemm/warp/testbed.h @@ -33,6 +33,7 @@ #include "cutlass/numeric_types.h" #include "cutlass/subbyte_reference.h" #include "cutlass/platform/platform.h" +#include "cutlass/arch/arch.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" @@ -100,9 +101,9 @@ __global__ void kernel( typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); - typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::LaneId()); + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); - typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::LaneId()); + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); FragmentA frag_A; FragmentB frag_B; @@ -129,7 +130,7 @@ __global__ void kernel( } } - typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::LaneId()); + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); iter_C.store(accum); } @@ -142,7 +143,7 @@ template < typename Mma_, /// Size of threadblock-scoped shape used to store SMEM typename ThreadblockShape_, - /// The innter product operation performed by GEMM + /// The inner product operation performed by GEMM typename Operator_ = 
cutlass::arch::OpMultiplyAdd > struct Testbed { @@ -205,8 +206,10 @@ struct Testbed { } uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - tensor_A.host_view(), seed, scope_max, scope_min, 0); + + cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), + tensor_A.capacity(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity()); @@ -230,8 +233,10 @@ struct Testbed { } uint64_t seed = 7; - cutlass::reference::host::TensorFillRandomUniform( - tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + + cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), + tensor_B.capacity(), seed, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), tensor_B.capacity()); @@ -313,23 +318,25 @@ struct Testbed { cutlass::TensorView tensor_A_physical( tensor_A.host_data(), - tensor_A.stride(), + tensor_A.stride()[0], tensor_A.extent()); cutlass::TensorView tensor_B_physical( tensor_B.host_data(), - tensor_B.stride(), + tensor_B.stride()[0], tensor_B.extent()); std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + << "A(physical - stride: " << tensor_A.stride()[0] + << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + << "B(physical - stride: " << tensor_B.stride()[0] + << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; 
std::cout << "C:\n" << tensor_C.host_view() << "\n\n" @@ -493,23 +500,23 @@ struct TestbedComplex { cutlass::TensorView tensor_A_physical( tensor_A.host_data(), - tensor_A.stride(), + tensor_A.stride()[0], tensor_A.extent()); cutlass::TensorView tensor_B_physical( tensor_B.host_data(), - tensor_B.stride(), + tensor_B.stride()[0], tensor_B.extent()); std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; std::cout << "C:\n" << tensor_C.host_view() << "\n\n" @@ -574,9 +581,9 @@ __global__ void kernel_transform( typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); - typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::LaneId()); + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); - typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::LaneId()); + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); FragmentA loaded_frag_A; FragmentB loaded_frag_B; @@ -608,7 +615,7 @@ __global__ void kernel_transform( } } - typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::LaneId()); + typename Mma::IteratorC iter_C({output_C, layout_C}, 
cutlass::arch::LaneId()); iter_C.store(accum); } @@ -790,23 +797,23 @@ struct TransformTestbed { cutlass::TensorView tensor_A_physical( tensor_A.host_data(), - tensor_A.stride(), + tensor_A.stride()[0], tensor_A.extent()); cutlass::TensorView tensor_B_physical( tensor_B.host_data(), - tensor_B.stride(), + tensor_B.stride()[0], tensor_B.extent()); std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << tensor_B.stride() << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; std::cout << "C:\n" << tensor_C.host_view() << "\n\n" @@ -970,23 +977,23 @@ struct TransformedTestbedComplex { cutlass::TensorView tensor_A_physical( tensor_A.host_data(), - tensor_A.stride(), + tensor_A.stride()[0], tensor_A.extent()); cutlass::TensorView tensor_B_physical( tensor_B.host_data(), - tensor_B.stride(), + tensor_B.stride()[0], tensor_B.extent()); std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "A:\n" << tensor_A.host_view() << "\n\n" - << "A(physical - stride: " << tensor_A.stride() << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; std::cout << "B:\n" << tensor_B.host_view() << "\n\n" - << "B(physical - stride: " << 
tensor_B.stride() << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; std::cout << "C:\n" << tensor_C.host_view() << "\n\n" @@ -1073,11 +1080,11 @@ __global__ void sparse_kernel( Mma::Shape::kK / Mma::kSparse / Mma::kElementsPerElementE / Mma::kInterleaved}); - typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::LaneId()); + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); - typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::LaneId()); + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); - typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::LaneId()); + typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); FragmentA frag_A; FragmentB frag_B; @@ -1108,7 +1115,7 @@ __global__ void sparse_kernel( } } - typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::LaneId()); + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); iter_C.store(accum); } diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h index fe5ac819..844dbfc9 100644 --- a/tools/library/include/cutlass/library/handle.h +++ b/tools/library/include/cutlass/library/handle.h @@ -139,24 +139,24 @@ public: ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices void const * ptr_A, /// Pointer to A matrix in Global Memory - int lda, /// Leading dimension of A matrix + int64_t lda, /// Leading dimension of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices void 
const * ptr_B, /// Pointer to B matrix in Global Memory - int ldb, /// Leading dimension of B matrix + int64_t ldb, /// Leading dimension of B matrix void const * beta, /// Pointer to beta scalar NumericTypeID element_C, /// Data type of C and D matrices void const * ptr_C, /// Pointer to C matrix - int ldc, /// Leading dimension of C matrix + int64_t ldc, /// Leading dimension of C matrix void * ptr_D, /// Pointer to D matrix - int ldd /// Leading dimension of D matrix + int64_t ldd /// Leading dimension of D matrix ); /// Executes a GEMM computation: D <= alpha * A*B + beta * C. @@ -182,24 +182,24 @@ public: ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices void const * ptr_A, /// Pointer to A matrix in Global Memory - int lda, /// Leading dimension of A matrix + int64_t lda, /// Leading dimension of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices void const * ptr_B, /// Pointer to B matrix in Global Memory - int ldb, /// Leading dimension of B matrix + int64_t ldb, /// Leading dimension of B matrix void const * beta, /// Pointer to beta scalar NumericTypeID element_C, /// Data type of C and D matrices void const * ptr_C, /// Pointer to C matrix - int ldc, /// Leading dimension of C matrix + int64_t ldc, /// Leading dimension of C matrix void * ptr_D, /// Pointer to D matrix - int ldd, /// Leading dimension of D matrix + int64_t ldd, /// Leading dimension of D matrix int batch_count = 1, /// Batch count or number of split-K slices @@ -231,8 +231,8 @@ public: void const * ptr_A_real, /// Pointer to real part of A matrix void const * ptr_A_imag, /// Pointer to imaginary part of A matrix - int lda_real, /// Leading dimension of real part of A matrix - int lda_imag, /// Leading dimension of imaginary part of A matrix + int64_t 
lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix @@ -240,8 +240,8 @@ public: void const * ptr_B_real, /// Pointer to real part of B matrix void const * ptr_B_imag, /// Pointer to imaginary part of B matrix - int ldb_real, /// Leading dimension of real part of B matrix - int ldb_imag, /// Leading dimension of imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix void const * beta, /// Pointer to beta scalar @@ -249,13 +249,13 @@ public: void const * ptr_C_real, /// Pointer to real part of C matrix void const * ptr_C_imag, /// Pointer to imaginary part of C matrix - int ldc_real, /// Leading dimension of real part of C matrix - int ldc_imag, /// Leading dimension of imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix void * ptr_D_real, /// Pointer to real part of D matrix void * ptr_D_imag, /// Pointer to imaginary part of D matrix - int ldd_real, /// Leading dimension of real part of D matrix - int ldd_imag, /// Leading dimension of imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix int batch_count = 1, /// Number of batched GEMMs to execute @@ -297,8 +297,8 @@ public: void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices - int lda_real, /// Leading dimension of real part of A matrix - int lda_imag, /// Leading dimension of imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real 
part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix @@ -307,8 +307,8 @@ public: void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices - int ldb_real, /// Leading dimension of real part of B matrix - int ldb_imag, /// Leading dimension of imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix void const * beta, /// Pointer to beta scalar @@ -317,14 +317,14 @@ public: void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices - int ldc_real, /// Leading dimension of real part of C matrix - int ldc_imag, /// Leading dimension of imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices - int ldd_real, /// Leading dimension of real part of D matrix - int ldd_imag /// Leading dimension of imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag /// Leading dimension of imaginary part of D matrix ); }; diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 18bfce24..14026a35 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -933,13 +933,13 @@ 
struct Conv2dConfiguration { conv::Conv2dProblemSize problem_size; // stride of operand A - std::vector stride_a; + std::vector stride_a; // stride of operand B - std::vector stride_b; + std::vector stride_b; // stride of operand C - std::vector stride_c; + std::vector stride_c; }; diff --git a/tools/library/scripts/conv2d_operation.py b/tools/library/scripts/conv2d_operation.py index e164bd00..b6757072 100644 --- a/tools/library/scripts/conv2d_operation.py +++ b/tools/library/scripts/conv2d_operation.py @@ -17,7 +17,7 @@ from library import * class Conv2dOperation: # def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ - stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1): self.operation_kind = OperationKind.Conv2d self.arch = arch diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 5316c5b9..b9bd1b41 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -141,18 +141,19 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t ########################################################################################################### # ConvolutionOperator support variations # ____________________________________________________________________ -# ConvolutionalOperator | Analytic | Optimized +# ConvolutionalOperator | Analytic | Optimized # ____________________________________________________________________ -# | Fprop | (strided) | (strided) -# | Dgrad | (strided, unity*) | (unity) -# | Wgrad | (strided) | (strided) +# | Fprop | (strided) | (strided) +# | Dgrad | (strided, unity*) | (strided, unity) +# | Wgrad | (strided) | (strided) # ____________________________________________________________________ # # Note : Operator marked 
(*) are supported but not generated to keep the instantiated kernel count low ########################################################################################################### # Convolution for 2D operations def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ - conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): element_a, element_b, element_c, element_epilogue = data_type @@ -169,33 +170,66 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme operations = [] for tile in tile_descriptions: - for conv_kind in conv_kinds: + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop for iterator_algorithm in iterator_algorithms: - A = TensorDescription(element_a, layout[0], alignment) - B = TensorDescription(element_b, layout[1], alignment) - C = TensorDescription(element_c, layout[2], alignment_c) + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - # unity stride only for Optimized Dgrad - if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad): - new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) + manifest.append(new_operation) + operations.append(new_operation) - manifest.append(new_operation) 
- operations.append(new_operation) + # + # Conv2d Dgrad + # + if ConvKind.Dgrad in conv_kinds: - # strided dgrad is not supported by Optimized Dgrad - if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad): - continue + # Unity stride for Analytic and Optimized Dgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_) - # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic) - new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # strided dgrad uses a special threadblock swizzle + # note that SwizzlingFunctor.StridedDgradHorizontal might be + # better for problem sizes with large activation channel count + swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 + + new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) + + manifest.append(new_operation) + operations.append(new_operation) + + # + # Conv2d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) manifest.append(new_operation) operations.append(new_operation) return operations + # Convolution for 3D operations def CreateConv3dOperator(manifest, layout, 
tile_descriptions, data_type, alignment, \ conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): @@ -315,6 +349,11 @@ def GenerateSM50_Simt_complex(manifest, args): for math_inst in math_instructions: tile_descriptions = [ + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), ] @@ -1272,6 +1311,7 @@ def GenerateSM80_TensorOp_16816(manifest, args): TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), @@ -1698,9 +1738,10 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args): TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([ 64, 
64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] @@ -1713,14 +1754,14 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args): operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, 32, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + for op in operations: if op.tile_description.threadblock_shape[1] >= 128: op.C.alignment = 8 @@ -1934,6 +1975,7 @@ def GenerateSM80_TensorOp_1688(manifest, args): TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), @@ -1993,7 +2035,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): [16, 8, 8], \ DataType.bf16, DataType.bf16, DataType.f32, \ OpcodeClass.TensorOp, \ - MathOperation.multiply_add_fast_bf16) + MathOperation.multiply_add_fast_bf16), ] min_cc = 80 @@ -2017,6 +2059,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, 
max_cc_smem_limited), TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2031,6 +2074,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, 4) # +# # def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args): @@ -2155,9 +2199,9 @@ def GenerateSM80_TensorOp_884(manifest, args): alignment_constraints = [1,] tile_descriptions = [ - TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2463,6 +2507,7 @@ if __name__ == "__main__": parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, help='Specify the output log file containing all enabled kernels in this build') + parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") args = parser.parse_args() diff --git a/tools/library/scripts/library.py 
b/tools/library/scripts/library.py index 5df09a89..21ef62bf 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -437,6 +437,10 @@ class SwizzlingFunctor(enum.Enum): Identity2 = enum_auto() Identity4 = enum_auto() Identity8 = enum_auto() + Horizontal = enum_auto() + StridedDgradIdentity1 = enum_auto() + StridedDgradIdentity4 = enum_auto() + StridedDgradHorizontal = enum_auto() # SwizzlingFunctorTag = { @@ -444,6 +448,10 @@ SwizzlingFunctorTag = { SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', + SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', + SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', } ################################################################################################### diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 409ec09a..536f97bf 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -101,6 +101,70 @@ void initialize_all_${operation_name}_operations(Manifest &manifest) { self.top_level_file.write(self.epilogue_template) self.top_level_file.close() +class EmitInterfaceLibrary: + def __init__(self, generated_path, operation_count, args): + self.generated_path = generated_path + self.args = args + + + self.prototypes = [] + self.fn_calls = [] + self.operation_count = str(operation_count) + + self.top_level_hdr_template = ''' +/* + Generated by manifest.py - Do not edit. 
+*/ +''' + self.top_level_prologue = ''' + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +\tnamespace library { + +${prototypes} + +\t\tvoid initialize_all(Manifest &manifest) { +\t\t\tmanifest.reserve(${operation_count});\n\n +${fn_calls} +\t\t\t} + +\t} // namespace library +} // namespace cutlass + +''' + + # + def __enter__(self): + self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp') + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.top_level_hdr_template) + + self.source_files = [self.top_level_path,] + + return self + + # + def emit(self, operation_name): + self.prototypes.append(SubstituteTemplate( + "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);", + {'operation_kind': operation_name})) + self.fn_calls.append(SubstituteTemplate( + "\t\t\tinitialize_all_${operation_kind}_operations(manifest);", + {'operation_kind': operation_name})) + + + + # + def __exit__(self, exception_type, exception_value, traceback): + self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes), + 'fn_calls':"\n".join(self.fn_calls), + 'operation_count': self.operation_count})) + self.top_level_file.close() + ################################################################################################### ################################################################################################### @@ -150,27 +214,6 @@ class Manifest: self.operation_count = 0 self.operations_by_name = {} - self.top_level_prologue = ''' - -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -namespace cutlass { -namespace library { - -${prototypes} - -void initialize_all(Manifest &manifest) { - -''' - self.top_level_reserve = ' manifest.reserve(${operation_count});\n\n' - self.top_level_epilogue = ''' -} - -} // namespace library -} // namespace cutlass - -''' def 
get_kernel_filters (self, kernelListFile): @@ -288,6 +331,9 @@ void initialize_all(Manifest &manifest) { operation_emitters = { GeneratorTarget.Library: EmitOperationKindLibrary } + interface_emitters = { + GeneratorTarget.Library: EmitInterfaceLibrary + } generated_path = os.path.join(self.args.curr_build_dir, 'generated') @@ -299,38 +345,20 @@ void initialize_all(Manifest &manifest) { source_files = [] - top_level_path = os.path.join(generated_path, 'initialize_all.cpp') - with open(top_level_path, 'w') as top_level_file: - - if target == GeneratorTarget.Library: - source_files.append(top_level_path) - - prototypes = [] + with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter: for operation_kind, configurations in self.operations.items(): - prototypes.append(SubstituteTemplate( - "void initialize_all_${operation_kind}_operations(Manifest &manifest);", - {'operation_kind': OperationKindNames[operation_kind]})) + iface_emitter.emit(OperationKindNames[operation_kind]) - top_level_file.write(SubstituteTemplate(self.top_level_prologue, - {'prototypes': "\n".join(prototypes)})) + source_files += iface_emitter.source_files - top_level_file.write(SubstituteTemplate( - self.top_level_reserve, {'operation_count': str(self.operation_count)})) - # for each operation kind, emit initializer for all configurations - for operation_kind, configurations in self.operations.items(): - - with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: - for configuration_name, operations in configurations.items(): - operation_kind_emitter.emit(configuration_name, operations) + # for each operation kind, emit initializer for all configurations + for operation_kind, configurations in self.operations.items(): + with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: + for configuration_name, operations in configurations.items(): + 
operation_kind_emitter.emit(configuration_name, operations) - source_files += operation_kind_emitter.source_files - - top_level_file.write(SubstituteTemplate( - " initialize_all_${operation_kind}_operations(manifest);\n", - {'operation_kind': OperationKindNames[operation_kind]})) - - top_level_file.write(self.top_level_epilogue) + source_files += operation_kind_emitter.source_files # write the manifest.cmake file containing paths from all targets manifest_path = os.path.join(generated_path, "manifest.cmake") diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 5dd2ed29..54939242 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -58,6 +58,8 @@ public: using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; @@ -102,7 +104,7 @@ public: OpcodeClassMap::kId; description_.tile_description.math_instruction.math_operation = - MathOperationMap::kId; + MathOperationMap::kId; description_.tile_description.minimum_compute_capability = ArchMap::kMin; @@ -141,7 +143,6 @@ public: using LayoutC = typename Operator::LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - using OperatorArguments = typename Operator::Arguments; public: @@ -160,10 +161,12 @@ protected: GemmConfiguration const *configuration) { operator_args.problem_size = configuration->problem_size; - operator_args.ref_A = {nullptr, int(configuration->lda)}; - operator_args.ref_B = {nullptr, int(configuration->ldb)}; - operator_args.ref_C = {nullptr, int(configuration->ldc)}; - operator_args.ref_D = {nullptr, 
int(configuration->ldd)}; + + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + operator_args.split_k_slices = configuration->split_k_slices; return Status::kSuccess; @@ -360,11 +363,11 @@ protected: SparseGemmConfiguration const *configuration) { operator_args.problem_size = configuration->problem_size; - operator_args.ref_A = {nullptr, int(configuration->lda)}; - operator_args.ref_B = {nullptr, int(configuration->ldb)}; - operator_args.ref_C = {nullptr, int(configuration->ldc)}; - operator_args.ref_D = {nullptr, int(configuration->ldd)}; - operator_args.ref_E = {nullptr, int(configuration->lde)}; + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + operator_args.ref_E = {nullptr, configuration->lde}; return Status::kSuccess; } @@ -562,10 +565,10 @@ protected: operator_args.problem_size = configuration->problem_size; operator_args.batch_count = configuration->batch_count; - operator_args.lda = int(configuration->lda); - operator_args.ldb = int(configuration->ldb); - operator_args.ldc = int(configuration->ldc); - operator_args.ldd = int(configuration->ldd); + operator_args.lda = (configuration->lda); + operator_args.ldb = (configuration->ldb); + operator_args.ldc = (configuration->ldc); + operator_args.ldd = (configuration->ldd); return Status::kSuccess; } @@ -755,14 +758,15 @@ protected: operator_args.problem_size = configuration->problem_size; operator_args.batch_count = configuration->batch_count; - operator_args.lda_real = int(configuration->lda_real); - operator_args.lda_imag = int(configuration->lda_imag); - operator_args.ldb_real = int(configuration->ldb_real); - operator_args.ldb_imag = int(configuration->ldb_imag); - 
operator_args.ldc_real = int(configuration->ldc_real); - operator_args.ldc_imag = int(configuration->ldc_imag); - operator_args.ldd_real = int(configuration->ldd_real); - operator_args.ldd_imag = int(configuration->ldd_imag); + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; return Status::kSuccess; } @@ -960,14 +964,14 @@ protected: operator_args.problem_size = configuration->problem_size; operator_args.batch_count = configuration->batch_count; - operator_args.lda_real = int(configuration->lda_real); - operator_args.lda_imag = int(configuration->lda_imag); - operator_args.ldb_real = int(configuration->ldb_real); - operator_args.ldb_imag = int(configuration->ldb_imag); - operator_args.ldc_real = int(configuration->ldc_real); - operator_args.ldc_imag = int(configuration->ldc_imag); - operator_args.ldd_real = int(configuration->ldd_real); - operator_args.ldd_imag = int(configuration->ldd_imag); + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; return Status::kSuccess; } diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index 6108bdc7..92df1d44 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -204,18 +204,18 @@ static int gemm_problem_alignment( int K, NumericTypeID element_A, void const *ptr_A, - 
int lda, + int64_t lda, int64_t batch_stride_A, NumericTypeID element_B, void const *ptr_B, - int ldb, + int64_t ldb, int64_t batch_stride_B, NumericTypeID element_C, void const * ptr_C, - int ldc, + int64_t ldc, int64_t batch_stride_C, void const * ptr_D, - int ldd, + int64_t ldd, int64_t batch_stride_D, int max_alignment_in_bytes = 16 ) { @@ -338,24 +338,24 @@ Status Handle::gemm( ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices void const * ptr_A, /// Pointer to A matrix in Global Memory - int lda, /// Leading dimension of A matrix + int64_t lda, /// Leading dimension of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices void const * ptr_B, /// Pointer to B matrix in Global Memory - int ldb, /// Leading dimension of B matrix + int64_t ldb, /// Leading dimension of B matrix void const * beta, /// Pointer to beta scalar NumericTypeID element_C, /// Data type of C and D matrices void const * ptr_C, /// Pointer to C matrix - int ldc, /// Leading dimension of C matrix + int64_t ldc, /// Leading dimension of C matrix void * ptr_D, /// Pointer to D matrix - int ldd /// Leading dimension of D matrix + int64_t ldd /// Leading dimension of D matrix ) { // @@ -494,24 +494,24 @@ Status Handle::gemm_universal( ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices void const * ptr_A, /// Pointer to A matrix in Global Memory - int lda, /// Leading dimension of A matrix + int64_t lda, /// Leading dimension of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices void const * ptr_B, /// Pointer to B matrix in Global 
Memory - int ldb, /// Leading dimension of B matrix + int64_t ldb, /// Leading dimension of B matrix void const * beta, /// Pointer to beta scalar NumericTypeID element_C, /// Data type of C and D matrices void const * ptr_C, /// Pointer to C matrix - int ldc, /// Leading dimension of C matrix + int64_t ldc, /// Leading dimension of C matrix void * ptr_D, /// Pointer to D matrix - int ldd, /// Leading dimension of D matrix + int64_t ldd, /// Leading dimension of D matrix int batch_count, /// Batch count or number of split-K slices @@ -672,8 +672,8 @@ Status Handle::gemm_planar_complex( void const * ptr_A_real, /// Pointer to real part of A matrix void const * ptr_A_imag, /// Pointer to imaginary part of A matrix - int lda_real, /// Leading dimension of real part of A matrix - int lda_imag, /// Leading dimension of imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix @@ -681,8 +681,8 @@ Status Handle::gemm_planar_complex( void const * ptr_B_real, /// Pointer to real part of B matrix void const * ptr_B_imag, /// Pointer to imaginary part of B matrix - int ldb_real, /// Leading dimension of real part of B matrix - int ldb_imag, /// Leading dimension of imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix void const * beta, /// Pointer to beta scalar @@ -690,13 +690,13 @@ Status Handle::gemm_planar_complex( void const * ptr_C_real, /// Pointer to real part of C matrix void const * ptr_C_imag, /// Pointer to imaginary part of C matrix - int ldc_real, /// Leading dimension of real part of C matrix - int ldc_imag, /// Leading dimension of imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t 
ldc_imag, /// Leading dimension of imaginary part of C matrix void * ptr_D_real, /// Pointer to real part of D matrix void * ptr_D_imag, /// Pointer to imaginary part of D matrix - int ldd_real, /// Leading dimension of real part of D matrix - int ldd_imag, /// Leading dimension of imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix int batch_count, /// Number of batched GEMMs to execute @@ -877,8 +877,8 @@ Status Handle::gemm_planar_complex_array( void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices - int lda_real, /// Leading dimension of real part of A matrix - int lda_imag, /// Leading dimension of imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix @@ -887,8 +887,8 @@ Status Handle::gemm_planar_complex_array( void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices - int ldb_real, /// Leading dimension of real part of B matrix - int ldb_imag, /// Leading dimension of imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix void const * beta, /// Pointer to beta scalar @@ -897,14 +897,14 @@ Status Handle::gemm_planar_complex_array( void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C 
matrices - int ldc_real, /// Leading dimension of real part of C matrix - int ldc_imag, /// Leading dimension of imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices - int ldd_real, /// Leading dimension of real part of D matrix - int ldd_imag /// Leading dimension of imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag /// Leading dimension of imaginary part of D matrix ) { // diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h index 811621c1..204706a1 100644 --- a/tools/library/src/reference/conv_reference_operation.h +++ b/tools/library/src/reference/conv_reference_operation.h @@ -115,13 +115,19 @@ struct ConvReferenceDispatcher< layout::TensorNHWC layout_c; layout_a.stride() = - make_Coord(config.stride_a[0], config.stride_a[1], config.stride_a[2]); + make_Coord(int32_t(config.stride_a[0]), + int32_t(config.stride_a[1]), + int32_t(config.stride_a[2])); layout_b.stride() = - make_Coord(config.stride_b[0], config.stride_b[1], config.stride_b[2]); + make_Coord(int32_t(config.stride_b[0]), + int32_t(config.stride_b[1]), + int32_t(config.stride_b[2])); layout_c.stride() = - make_Coord(config.stride_c[0], config.stride_c[1], config.stride_c[2]); + make_Coord(int32_t(config.stride_c[0]), + int32_t(config.stride_c[1]), + int32_t(config.stride_c[2])); if (kProvider == Provider::kReferenceHost) { diff --git a/tools/profiler/src/conv2d_operation_profiler.h b/tools/profiler/src/conv2d_operation_profiler.h index 2f99b67c..61ea1d8e 100644 --- a/tools/profiler/src/conv2d_operation_profiler.h +++ b/tools/profiler/src/conv2d_operation_profiler.h 
@@ -274,9 +274,9 @@ public: library::LayoutTypeID const &layout_a, library::LayoutTypeID const &layout_b, library::LayoutTypeID const &layout_c) { - std::vector stride_activations; - std::vector stride_filters; - std::vector stride_output; + std::vector stride_activations; + std::vector stride_filters; + std::vector stride_output; // Strides for interleaved fprop if (conv_kind == library::ConvKind::kFprop && diff --git a/tools/profiler/src/conv3d_operation_profiler.h b/tools/profiler/src/conv3d_operation_profiler.h index 2192a984..f41f8cae 100644 --- a/tools/profiler/src/conv3d_operation_profiler.h +++ b/tools/profiler/src/conv3d_operation_profiler.h @@ -268,7 +268,7 @@ public: A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } // Returns stride vector for tensor A - std::vector stride_a(library::ConvKind const &conv_kind) { + std::vector stride_a(library::ConvKind const &conv_kind) { return { configuration.layout_a(conv_kind).stride()[0], configuration.layout_a(conv_kind).stride()[1], @@ -278,7 +278,7 @@ public: } // Returns stride vector for tensor B - std::vector stride_b(library::ConvKind const &conv_kind) { + std::vector stride_b(library::ConvKind const &conv_kind) { return { configuration.layout_b(conv_kind).stride()[0], @@ -289,7 +289,7 @@ public: } // Returns stride vector for tensor C - std::vector stride_c(library::ConvKind const &conv_kind) { + std::vector stride_c(library::ConvKind const &conv_kind) { return { configuration.layout_c(conv_kind).stride()[0], diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index c53e8c22..ae1e4c8a 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -67,6 +67,15 @@ CutlassProfiler::~CutlassProfiler() { /// Execute the program int CutlassProfiler::operator()() { + if (options_.cmdline.num_naked_args() > 0) { + std::cerr << "Unknown args: \n"; + options_.cmdline.print_naked_args(std::cerr); + std::cerr << 
"\n\n\n"; + + print_usage_(std::cout); + return 1; + } + if (options_.about.help) { if (options_.operation_kind == library::OperationKind::kInvalid) { print_usage_(std::cout); diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 38a4acbe..7a582562 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -54,7 +54,7 @@ size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) { ///////////////////////////////////////////////////////////////////////////////////////////////// template -static std::vector get_packed_layout_stride(std::vector const &extent) { +static std::vector get_packed_layout_stride(std::vector const &extent) { typename Layout::TensorCoord extent_coord; typename Layout::Stride stride_coord; @@ -67,25 +67,25 @@ static std::vector get_packed_layout_stride(std::vector const &extent) extent_coord[i] = extent.at(i); } - std::vector stride; + std::vector stride; stride.resize(Layout::kStrideRank, 0); Layout layout = Layout::packed(extent_coord); stride_coord = layout.stride(); for (int i = 0; i < Layout::kStrideRank; ++i) { - stride.at(i) = stride_coord[i]; + stride.at(i) = (int64_t)stride_coord[i]; } return stride; } /// Returns the stride of a packed layout -std::vector DeviceAllocation::get_packed_layout( +std::vector DeviceAllocation::get_packed_layout( library::LayoutTypeID layout_id, std::vector const &extent) { - std::vector stride; + std::vector stride; switch (layout_id) { case library::LayoutTypeID::kColumnMajor: @@ -159,7 +159,7 @@ static size_t construct_layout_( void *bytes, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector &stride) { + std::vector &stride) { if (extent.size() != Layout::kRank) { throw std::runtime_error( @@ -183,7 +183,7 @@ static size_t construct_layout_( typename Layout::Stride stride_coord; for (int i = 0; i < Layout::kStrideRank; ++i) { - stride_coord[i] = stride.at(i); + stride_coord[i] 
= (int)stride.at(i); } typename Layout::TensorCoord extent_coord; @@ -210,7 +210,7 @@ size_t DeviceAllocation::construct_layout( void *bytes, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector &stride) { + std::vector &stride) { switch (layout_id) { case library::LayoutTypeID::kColumnMajor: @@ -309,7 +309,7 @@ DeviceAllocation::DeviceAllocation( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride, + std::vector const &stride, int batch_count ): type_(type), batch_stride_(size_t(0)), capacity_(size_t(0)), pointer_(nullptr), batch_count_(1) { @@ -370,12 +370,12 @@ DeviceAllocation &DeviceAllocation::reset( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride, + std::vector const &stride, int batch_count) { reset(); - tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int) * library::get_layout_stride_rank(layout_id)), 0); + tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int64_t) * library::get_layout_stride_rank(layout_id)), 0); type_ = type; @@ -422,7 +422,7 @@ library::LayoutTypeID DeviceAllocation::layout() const { return layout_; } -std::vector const & DeviceAllocation::stride() const { +std::vector const & DeviceAllocation::stride() const { return stride_; } @@ -1277,6 +1277,15 @@ struct vector_to_coord { vector_to_coord(coord, vec); } } + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + coord[Rank - 1] = (int)vec.at(Rank - 1); + + if (Rank > 1) { + vector_to_coord(coord, vec); + } + } }; /// Permits copying dynamic vectors into static-length vectors @@ -1287,6 +1296,11 @@ struct vector_to_coord { coord[0] = vec.at(0); } + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + coord[0] = (int)vec.at(0); + } }; /// Permits copying dynamic vectors into static-length vectors @@ -1306,7 +1320,7 @@ static void write_tensor_csv_static_tensor_view( DeviceAllocation &allocation) { 
Coord extent; - Coord stride; + Coord stride; if (allocation.extent().size() != Layout::kRank) { throw std::runtime_error("Allocation extent has invalid rank"); @@ -1317,7 +1331,8 @@ static void write_tensor_csv_static_tensor_view( } vector_to_coord, Layout::kRank>(extent, allocation.extent()); - vector_to_coord, Layout::kStrideRank>(stride, allocation.stride()); + vector_to_coord, + Layout::kStrideRank>(stride, allocation.stride()); Layout layout(stride); HostTensor host_tensor(extent, layout, false); @@ -1498,6 +1513,162 @@ void DeviceAllocation::write_tensor_csv( } } +template +static void tensor_fill_tensor_view(DeviceAllocation &allocation, Element val = Element()) { + Coord extent; + Coord stride; + + if (allocation.extent().size() != Layout::kRank) { + throw std::runtime_error("Allocation extent has invalid rank"); + } + + if (allocation.stride().size() != Layout::kStrideRank) { + throw std::runtime_error("Allocation stride has invalid rank"); + } + + vector_to_coord, Layout::kRank>(extent, allocation.extent()); + vector_to_coord, + Layout::kStrideRank>(stride, allocation.stride()); + + TensorView view( + static_cast(allocation.data()), + Layout(stride), + extent + ); + + + cutlass::reference::device::TensorFill( + view, + val + ); +} + +template +static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) { + switch (allocation.layout()) { + case library::LayoutTypeID::kRowMajor: + tensor_fill_tensor_view(allocation, val); + break; + case library::LayoutTypeID::kColumnMajor: + tensor_fill_tensor_view(allocation, val); + break; + case library::LayoutTypeID::kTensorNHWC: + tensor_fill_tensor_view(allocation, val); + break; + case library::LayoutTypeID::kTensorNDHWC: + tensor_fill_tensor_view(allocation, val); + break; + case library::LayoutTypeID::kTensorNC32HW32: + tensor_fill_tensor_view>(allocation, val); + break; + case library::LayoutTypeID::kTensorNC64HW64: + tensor_fill_tensor_view>(allocation, val); + break; + case 
library::LayoutTypeID::kTensorC32RSK32: + tensor_fill_tensor_view>(allocation, val); + break; + case library::LayoutTypeID::kTensorC64RSK64: + tensor_fill_tensor_view>(allocation, val); + break; + default: + throw std::runtime_error("Unsupported layout"); + break; + } +} + +/// Fills a tensor uniformly with a value (most frequently used to clear the tensor) +void DeviceAllocation::fill(double val = 0.0) { + + switch (this->type()) { + case library::NumericTypeID::kF16: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kBF16: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kTF32: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kF32: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kF64: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kS2: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kS4: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kS8: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kS16: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kS32: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kS64: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kB1: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kU2: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kU4: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kU8: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kU16: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kU32: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kU64: + tensor_fill(*this, 
static_cast(val)); + break; + + case library::NumericTypeID::kCF16: + tensor_fill >(*this, from_real(val)); + break; + + case library::NumericTypeID::kCF32: + tensor_fill >(*this, from_real(val)); + break; + + case library::NumericTypeID::kCF64: + tensor_fill >(*this, from_real(val)); + break; + + default: + throw std::runtime_error("Unsupported numeric type"); + } +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler diff --git a/tools/profiler/src/device_allocation.h b/tools/profiler/src/device_allocation.h index 0aa9d0ec..f44afcc8 100644 --- a/tools/profiler/src/device_allocation.h +++ b/tools/profiler/src/device_allocation.h @@ -64,7 +64,7 @@ private: library::LayoutTypeID layout_; /// Stride vector - std::vector stride_; + std::vector stride_; /// Extent vector std::vector extent_; @@ -84,7 +84,7 @@ public: static size_t bytes(library::NumericTypeID type, size_t capacity); /// Returns the stride of a packed layout - static std::vector get_packed_layout( + static std::vector get_packed_layout( library::LayoutTypeID layout_id, std::vector const &extent); @@ -93,7 +93,7 @@ public: void *bytes, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector &stride); + std::vector &stride); /// Returns true if two blocks have exactly the same value static bool block_compare_equal( @@ -124,7 +124,7 @@ public: library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector(), + std::vector const &stride = std::vector(), int batch_count = 1); ~DeviceAllocation(); @@ -139,7 +139,7 @@ public: library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector(), + std::vector const &stride = std::vector(), int batch_count = 1); /// Returns a buffer owning the tensor reference @@ -162,7 +162,7 @@ public: library::LayoutTypeID layout() const; /// Gets the stride 
vector - std::vector const & stride() const; + std::vector const & stride() const; /// Gets the extent vector std::vector const & extent() const; @@ -193,6 +193,9 @@ public: /// Initializes a host allocation to a random distribution using std::cout void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); + + /// Uniformly fills a tensor with a value when provided o.w. zero + void fill(double value); /// Copies from an equivalent-sized tensor in device memory void copy_from_device(void const *ptr); diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index 3ab6b4c7..2437059b 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -52,7 +52,7 @@ DeviceAllocation *DeviceContext::allocate_tensor( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride, + std::vector const &stride, int batch_count) { device_memory_.emplace_back(type, layout_id, extent, stride, batch_count); @@ -69,7 +69,7 @@ DeviceAllocation *DeviceContext::allocate_tensor( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride, + std::vector const &stride, int batch_count) { DeviceAllocation *allocation = @@ -133,7 +133,7 @@ DeviceAllocation *DeviceContext::allocate_sparsemeta_tensor( library::LayoutTypeID layout_id, library::NumericTypeID type_a, std::vector const &extent, - std::vector const &stride, + std::vector const &stride, int batch_count) { DeviceAllocation *allocation = diff --git a/tools/profiler/src/device_context.h b/tools/profiler/src/device_context.h index 5e74f07e..1f2a32ed 100644 --- a/tools/profiler/src/device_context.h +++ b/tools/profiler/src/device_context.h @@ -77,7 +77,7 @@ public: library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector(), + std::vector const &stride = std::vector(), int 
batch_count = 1); /// Allocates memory of a given type, capacity (elements), and name @@ -87,7 +87,7 @@ public: library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector(), + std::vector const &stride = std::vector(), int batch_count = 1); /// Allocates memory for sparse meta data @@ -98,7 +98,7 @@ public: library::LayoutTypeID layout_id, library::NumericTypeID type_a, std::vector const &extent, - std::vector const &stride = std::vector(), + std::vector const &stride = std::vector(), int batch_count = 1); /// Clears named allocations (but does not necessarily free memory) diff --git a/tools/util/include/cutlass/util/command_line.h b/tools/util/include/cutlass/util/command_line.h index 31187a79..2c818293 100644 --- a/tools/util/include/cutlass/util/command_line.h +++ b/tools/util/include/cutlass/util/command_line.h @@ -90,11 +90,19 @@ struct CommandLine { /** * Returns number of naked (non-flag and non-key-value) commandline parameters */ - template - int num_naked_args() const { + size_t num_naked_args() const { return args.size(); } + /** + * Print naked (non-flag and non-key-value) commandline parameters + */ + void print_naked_args(std::ostream &out) const { + for (auto arg : args) { + out << " " << arg <<"\n"; + } + } + /** * Returns the commandline parameter for a given index (not including flags) */ diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index f105434f..804dcfa2 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -325,12 +325,12 @@ public: } /// Returns the layout object's stride in a given physical dimension - Index stride(int dim) const { + LongIndex stride(int dim) const { return layout_.stride().at(dim); } /// Returns the layout object's stride in a given physical dimension - Index & stride(int dim) { + LongIndex & stride(int dim) { return 
layout_.stride().at(dim); } diff --git a/tools/util/include/cutlass/util/index_sequence.h b/tools/util/include/cutlass/util/index_sequence.h new file mode 100644 index 00000000..ead17a10 --- /dev/null +++ b/tools/util/include/cutlass/util/index_sequence.h @@ -0,0 +1,52 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +/** + * \file + * \brief C++11 version of index_sequence. + */ + +namespace cutlass { + +template +struct index_sequence; + +template +struct index_sequence_helper : index_sequence_helper {}; + +template +struct index_sequence_helper<0, 0, Next...> { + using type = index_sequence<0, Next...>; +}; + +template +using make_index_sequence = typename index_sequence_helper::type; + +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/tools/util/include/cutlass/util/reference/device/kernel/gemm.h index 0e5c668e..f22c75ef 100644 --- a/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -65,8 +65,8 @@ __global__ void Gemm( // Map each thread to a unique tile of the output matrix MatrixCoord output_coord( - (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow, - (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) ); // Compute the general matrix product diff --git a/tools/util/include/cutlass/util/reference/host/convolution.h b/tools/util/include/cutlass/util/reference/host/convolution.h index f69ba174..3e4cb198 100644 --- a/tools/util/include/cutlass/util/reference/host/convolution.h +++ b/tools/util/include/cutlass/util/reference/host/convolution.h @@ -39,6 +39,7 @@ #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv2d_problem_size.h" #include "cutlass/conv/conv3d_problem_size.h" +#include namespace cutlass { namespace reference { @@ -243,7 +244,21 @@ void Conv2dDgrad( p = p / problem_size.stride_h; q = q / problem_size.stride_w; - +#if 0 + std::cout << "row:" + << n * problem_size.H * 
problem_size.W + + h * problem_size.W + + w << " " + << "n, p, q: (" + << n << ", " + << p << ", " + << q << ") * " + << "r, s: (" + << r << ", " + << s << ") [" + << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" + << std::endl; +#endif if (p < problem_size.P && q < problem_size.Q) { ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); diff --git a/tools/util/include/cutlass/util/reference/host/error_metrics.h b/tools/util/include/cutlass/util/reference/host/error_metrics.h new file mode 100644 index 00000000..d0f50a90 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/error_metrics.h @@ -0,0 +1,60 @@ + +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/util/reference/host/tensor_reduce.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorRelativeErrorMetric( + TensorView view_A_computed, + TensorView view_B_reference, + ComputeType identity = ComputeType() +) { + + return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / + cutlass::reference::host::TensorNorm(view_B_reference, identity); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 7904b746..3eea7747 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -36,6 +36,7 @@ // Cutlass includes #include "cutlass/cutlass.h" #include "cutlass/complex.h" +#include "cutlass/quaternion.h" #include "cutlass/array.h" 
#include "cutlass/numeric_types.h" #include "cutlass/subbyte_reference.h" @@ -219,6 +220,56 @@ struct RandomGaussianFunc > { } }; +/// Partial specialization for initializing a complex value. +template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(rnd / double(1 << int_scale)); + } + else { + reals[i] = from_real(rnd); + } + } + + return Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + /// Computes a random Gaussian distribution template < typename Element, ///< Element type @@ -429,6 +480,58 @@ struct RandomUniformFunc > { } }; +/// Partial specialization for initializing a Quaternion value. 
+template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + /// Computes a random Gaussian distribution template < typename Element, ///< Element type @@ -510,6 +613,32 @@ void TensorFillRandomUniform( TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits); } + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView, Layout> dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ detail::RandomUniformFunc> random_func(seed, max, min, bits); + + detail::TensorFillRandomUniformFunc, Layout> func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Fills a tensor with random values with a uniform random distribution. template <