CHANGELOG.md
@@ -1,5 +1,18 @@
# NVIDIA CUTLASS Changelog

## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
* [A new conceptual operation hierarchy](media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x, and [documentation for CUTLASS 3.0's GEMM API changes](/media/docs/gemm_api_3x.md).
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](media/docs/cutlass_3x_backwards_compatibility.md).
* Updates to [Functionality](media/docs/functionality.md), which directs users to the kernels supported via CUTLASS 2.x and CUTLASS 3.x.
* Updates to the [Compatibility](/README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, hardware architectures, and [Target Architecture](/README.md#Target-Architecture).
* New warp-specialized GEMM [kernel schedules](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting the Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
* Extensions to the CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
* [CUTLASS library integration](/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling the CUTLASS profiler.
* Support for [Hopper GEMMs](examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
* A set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](examples/48_hopper_warp_specialized_gemm), [49](examples/49_hopper_gemm_schedules_with_collective_builder), and [50](examples/50_hopper_gemm_with_epilogue_swizzle).

## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
* [Stream-K](/examples/47_ampere_gemm_universal_streamk), a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
* [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for fixed sequence lengths, and the other uses grouped GEMM for variable sequence lengths. Both versions need only one kernel.

CITATION.cff
@@ -5,33 +5,61 @@ message: >-
|
||||
following metadata.
|
||||
type: software
|
||||
authors:
|
||||
- given-names: Andrew
|
||||
email: akerr@nvidia.com
|
||||
family-names: Kerr
|
||||
- given-names: Vijay
|
||||
family-names: Thakkar
|
||||
email: vithakkar@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Pradeep
|
||||
family-names: Ramani
|
||||
email: prramani@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Cris
|
||||
family-names: Cecka
|
||||
email: ccecka@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Aniket
|
||||
family-names: Shivam
|
||||
email: ashivam@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Honghao
|
||||
family-names: Lu
|
||||
email: honghaol@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Ethan
|
||||
family-names: Yan
|
||||
email: etyan@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Jack
|
||||
family-names: Kosaian
|
||||
email: jkosaian@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Mark
|
||||
family-names: Hoemmen
|
||||
email: mhoemmen@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Haicheng
|
||||
family-names: Wu
|
||||
affiliation: NVIDIA
|
||||
email: haichengw@nvidia.com
|
||||
- given-names: Manish
|
||||
family-names: Gupta
|
||||
affiliation: Google
|
||||
email: manigupta@google.com
|
||||
- given-names: Dustyn
|
||||
family-names: Blasig
|
||||
email: dblasig@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Pradeep
|
||||
family-names: Ramini
|
||||
email: prramani@nvidia.com
|
||||
- given-names: Andrew
|
||||
family-names: Kerr
|
||||
email: akerr@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Matt
|
||||
family-names: Nicely
|
||||
email: mnicely@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Duane
|
||||
family-names: Merrill
|
||||
email: dumerrill@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Aniket
|
||||
family-names: Shivam
|
||||
email: ashivam@nvidia.com
|
||||
- given-names: Dustyn
|
||||
family-names: Blasig
|
||||
email: dblasig@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Fengqi
|
||||
family-names: Qiao
|
||||
email: fqiao@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Piotr
|
||||
family-names: Majcher
|
||||
@@ -49,10 +77,12 @@ authors:
|
||||
family-names: Wang
|
||||
email: jinw@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Matt
|
||||
family-names: Nicely
|
||||
email: mnicely@nvidia.com
|
||||
affiliation: NVIDIA
|
||||
- given-names: Manish
|
||||
family-names: Gupta
|
||||
affiliation: Google
|
||||
email: manigupta@google.com
|
||||
|
||||
|
||||
repository-code: 'https://github.com/NVIDIA/cutlass'
|
||||
abstract: >-
|
||||
CUTLASS is a collection of CUDA C++ template
|
||||
@@ -71,12 +101,12 @@ abstract: >-
|
||||
flexibility simplifies their use as building blocks
|
||||
within custom kernels and applications.
|
||||
keywords:
|
||||
- 'cutlass, tensor cores, cuda'
|
||||
- 'cutlass, tensor cores, cuda, cute, nvidia, gpu, linear algebra, matrix computations'
|
||||
license: BSD-3-Clause
|
||||
license-url: https://github.com/NVIDIA/cutlass/blob/v2.11.0/LICENSE.txt
|
||||
version: '2.11.0'
|
||||
date-released: '2022-11-19'
|
||||
license-url: https://github.com/NVIDIA/cutlass/blob/v3.0.0/LICENSE.txt
|
||||
version: '3.0.0'
|
||||
date-released: '2023-01-23'
|
||||
identifiers:
|
||||
- type: url
|
||||
value: "https://github.com/NVIDIA/cutlass/tree/v2.11.0"
|
||||
description: The GitHub release URL of tag 2.11.0
|
||||
value: "https://github.com/NVIDIA/cutlass/tree/v3.0.0"
|
||||
description: The GitHub release URL of tag 3.0.0
|
||||
|
||||
CMakeLists.txt
@@ -26,7 +26,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
|
||||
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
|
||||
|
||||
if(cutlass_LOADED)
|
||||
# If CUTLASS has been previously fetched and loaded, don't do it again.
|
||||
@@ -39,35 +39,40 @@ endif()
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
|
||||
|
||||
project(CUTLASS VERSION 2.11.0 LANGUAGES CXX)
|
||||
project(CUTLASS VERSION 3.0.0 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 10.2)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.")
|
||||
elseif (CUDA_VERSION VERSION_LESS 11.0)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.")
|
||||
if (CUDA_VERSION VERSION_LESS 11.3)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.")
|
||||
elseif (CUDA_VERSION VERSION_LESS 11.4)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.")
|
||||
endif()
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.5)
|
||||
message(FATAL_ERROR "GCC version must be at least 7.5!")
|
||||
endif()
|
||||
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
|
||||
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
|
||||
endif()
|
||||
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
#
|
||||
# CUTLASS 2.x requires C++11
|
||||
# CUTLASS 3.x requires C++17
|
||||
#
|
||||
if (NOT IMPLICIT_CMAKE_CXX_STANDARD)
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
endif()
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA)
|
||||
set(CMAKE_CUDA_STANDARD 11)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
|
||||
else()
|
||||
if (NOT IMPLICIT_CMAKE_CXX_STANDARD)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11)
|
||||
endif()
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17)
|
||||
endif()
|
||||
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
|
||||
endif()
|
||||
@@ -107,29 +112,14 @@ if (CUTLASS_ENABLE_TESTS)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
|
||||
if (NOT CUDA_VERSION VERSION_LESS 7.5)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 8.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 9.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 9.2)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 72)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 10.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90)
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a)
|
||||
endif()
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
@@ -271,6 +261,7 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
|
||||
endif()
|
||||
|
||||
|
||||
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
@@ -288,6 +279,15 @@ if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_OPENMP_TESTS)
|
||||
find_package(OpenMP)
|
||||
if(OpenMP_CXX_FOUND)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS})
|
||||
else()
|
||||
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
|
||||
|
||||
@@ -313,10 +313,6 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
|
||||
endif()
|
||||
|
||||
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
|
||||
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
|
||||
endif()
|
||||
|
||||
# There are numerous Clang versions that can work with each CUDA toolkit and the
# checks are not very useful so we are turning them off and using testing to
|
||||
# ensure the various combinations work properly.
|
||||
@@ -341,6 +337,7 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags)
|
||||
|
||||
link_libraries(nvidia::cudart)
|
||||
link_libraries(nvidia::cuda_driver)
|
||||
endif()
|
||||
|
||||
# Support for 128-bit integers if using NVIDIA C++ compiler
|
||||
@@ -530,6 +527,8 @@ target_include_directories(
|
||||
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
|
||||
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
|
||||
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/examples>
|
||||
)
|
||||
|
||||
install(
|
||||
|
||||
CONTRIBUTORS.md
@@ -7,63 +7,77 @@
|
||||
This is the official list of CUTLASS developers and contributors.
|
||||
|
||||
## DEVELOPERS
|
||||
Andrew Kerr
|
||||
Haicheng Wu
|
||||
Manish Gupta
|
||||
Dustyn Blasig
|
||||
Pradeep Ramani
|
||||
Cris Cecka
|
||||
Vijay Thakkar
|
||||
Aniket Shivam
|
||||
Honghao Lu
|
||||
Ethan Yan
|
||||
Zhaodong Chen
|
||||
Jack Kosaian
|
||||
Yujia Zhai
|
||||
Naila Farooqui
|
||||
Piotr Majcher
|
||||
Paul Springer
|
||||
Jin Wang
|
||||
Chinmay Talegaonkar
|
||||
Shang Zhang
|
||||
Scott Yokim
|
||||
Markus Hohnerbach
|
||||
Aditya Atluri
|
||||
David Tanner
|
||||
Manikandan Ananth
|
||||
Vijay Thakkar<br />
|
||||
Pradeep Ramani<br />
|
||||
Cris Cecka<br />
|
||||
Aniket Shivam<br />
|
||||
Jack Kosaian<br />
|
||||
Mark Hoemmen<br />
|
||||
Honghao Lu<br />
|
||||
Ethan Yan<br />
|
||||
Haicheng Wu<br />
|
||||
Andrew Kerr<br />
|
||||
Dustyn Blasig<br />
|
||||
Fengqi Qiao<br />
|
||||
Duane Merrill<br />
|
||||
Yujia Zhai<br />
|
||||
Shang Zhang<br />
|
||||
Piotr Majcher<br />
|
||||
Paul Springer<br />
|
||||
Markus Hohnerbach<br />
|
||||
Jin Wang<br />
|
||||
Aditya Atluri<br />
|
||||
|
||||
## CuTe
|
||||
Cris Cecka<br />
|
||||
Vijay Thakkar<br />
|
||||
|
||||
## CUTLASS Product Manager
|
||||
Matthew Nicely
|
||||
|
||||
Matthew Nicely<br />
|
||||
|
||||
## Former CUTLASS Developers
|
||||
Manish Gupta<br />
|
||||
Naila Farooqui<br />
|
||||
David Tanner<br />
|
||||
Manikandan Ananth<br />
|
||||
Zhaodong Chen<br />
|
||||
Chinmay Talegaonkar<br />
|
||||
|
||||
## CONTRIBUTORS
|
||||
Timothy Costa
|
||||
Julien Demouth
|
||||
Brian Fahs
|
||||
Michael Goldfarb
|
||||
Mostafa Hagog
|
||||
Fei Hu
|
||||
Alan Kaatz
|
||||
Tina Li
|
||||
Timmy Liu
|
||||
Duane Merrill
|
||||
Kevin Siu
|
||||
Markus Tavenrath
|
||||
John Tran
|
||||
Vicki Wang
|
||||
Junkai Wu
|
||||
Fung Xie
|
||||
Albert Xu
|
||||
Jack Yang
|
||||
Xiuxia Zhang
|
||||
Nick Zhao
|
||||
Timothy Costa<br />
|
||||
Julien Demouth<br />
|
||||
Brian Fahs<br />
|
||||
Michael Garland<br />
|
||||
Michael Goldfarb<br />
|
||||
Mostafa Hagog<br />
|
||||
Fei Hu<br />
|
||||
Alan Kaatz<br />
|
||||
Tina Li<br />
|
||||
Timmy Liu<br />
|
||||
Wei Liu<br />
|
||||
Duane Merrill<br />
|
||||
Kevin Siu<br />
|
||||
Markus Tavenrath<br />
|
||||
John Tran<br />
|
||||
Vicki Wang<br />
|
||||
Junkai Wu<br />
|
||||
Fung Xie<br />
|
||||
Albert Xu<br />
|
||||
Yang Xu<br />
|
||||
Jack Yang<br />
|
||||
Scott Yokim<br />
|
||||
Xiuxia Zhang<br />
|
||||
Nick Zhao<br />
|
||||
|
||||
## ACKNOWLEDGEMENTS
|
||||
|
||||
Girish Bharambe
|
||||
Luke Durant
|
||||
Olivier Giroux
|
||||
Stephen Jones
|
||||
Rishkul Kulkarni
|
||||
Bryce Lelbach
|
||||
Joel McCormack
|
||||
Kyrylo Perelygin
|
||||
Girish Bharambe<br />
|
||||
Luke Durant<br />
|
||||
Carter Edwards<br />
|
||||
Olivier Giroux<br />
|
||||
Stephen Jones<br />
|
||||
Rishkul Kulkarni<br />
|
||||
Bryce Lelbach<br />
|
||||
Joel McCormack<br />
|
||||
Kyrylo Perelygin<br />
|
||||
Sean Treichler<br />
|
||||
|
||||
README.md
@@ -1,18 +1,18 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 2.11
# CUTLASS 3.0

_CUTLASS 2.11 - November 2022_
_CUTLASS 3.0 - January 2023_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) and related computations at all levels
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
these "moving parts" into reusable, modular software components abstracted by C++ template
classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
resulting flexibility simplifies their use as building blocks within custom kernels
and applications.
classes. Primitives for different levels of a conceptual parallelization hierarchy
can be specialized and tuned via custom tiling sizes, data types,
and other algorithmic policy. The resulting flexibility simplifies their use
as building blocks within custom kernels and applications.

To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
@@ -21,60 +21,75 @@ point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32),
[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, and Ampere architectures.

CUTLASS implements high-performance Convolution via the implicit GEMM algorithm.
Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of
CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below.
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.

See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

# What's New in CUTLASS 2.11
CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data.
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.

CUTLASS 2.11 is an update to CUTLASS adding:
- [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
- [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths.
- [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel.
- Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
- [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
- [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm).
- [Optimized Group Conv](/examples/42_ampere_tensorop_group_conv).
- [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop).
- [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs.
- [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
- Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
- **Deprecation announcement:** CUTLASS plans to deprecate the following in the next major release:
  - Maxwell and Pascal GPU architectures
  - Ubuntu 16.04
  - CUDA 10.2
  - C++ 11
- **Future requirement announcement:** CUTLASS plans to add the following requirements in the next major release:
  - Minimum C++ standard - C++17

The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
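To make the `Layout`/`Tensor` vocabulary above concrete, here is a minimal, illustrative CuTe sketch. It is not part of this diff; the shape, strides, and host buffer are made up for illustration and it simply assumes the CuTe headers shipped with CUTLASS 3.0:

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // A hypothetical 8x16 column-major layout: shape (8,16) with strides (1,8).
  auto layout = make_layout(make_shape(Int<8>{}, Int<16>{}),
                            make_stride(Int<1>{}, Int<8>{}));

  // Compose the layout with a data array to obtain a Tensor; CuTe performs
  // the index arithmetic, so tensor(i, j) addresses buffer[i + 8 * j].
  float buffer[8 * 16] = {};
  auto tensor = make_tensor(&buffer[0], layout);
  tensor(3, 5) = 1.0f;

  // Layouts can be combined and manipulated via functional composition,
  // e.g. partitioning the tensor into 4x8 tiles for parallel agents.
  auto tiles = tiled_divide(tensor, Shape<_4, _8>{});
  print(layout);

  return 0;
}
```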

CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).

In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM, thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
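For intuition, the forward-propagation case maps onto GEMM extents roughly as sketched below. This is a hedged illustration with made-up tensor sizes, not code from this diff; see the Implicit GEMM Convolution document linked later for the authoritative description:

```cpp
#include <cstdio>

int main() {
  // Hypothetical NHWC convolution: input N x H x W x C, filter K x R x S x C,
  // unit stride and no padding, so the output is N x P x Q x K.
  int N = 8, H = 56, W = 56, C = 64;
  int K = 128, R = 3, S = 3;
  int P = H - R + 1, Q = W - S + 1;

  // Implicit GEMM computes the same result as a GEMM of these extents,
  // without ever materializing the im2col matrix in memory.
  int gemm_M = N * P * Q;   // one GEMM row per output pixel
  int gemm_N = K;           // one GEMM column per output channel
  int gemm_K = C * R * S;   // reduction over input channels and filter taps

  std::printf("GEMM extents: M=%d N=%d K=%d\n", gemm_M, gemm_N, gemm_K);
  return 0;
}
```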

# What's New in CUTLASS 3.0

CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogeneous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture.

- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md)
- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md).
- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe.
- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe.
- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs.
- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA.
- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality.
- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour of a metaprogramming-based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu) and the condensed sketch after this list.
- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format
for kernel procedural names.
- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release.
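
As referenced in the `CollectiveBuilder` bullet above, the following condensed sketch shows how a 3.0 kernel is assembled from collectives. It mirrors the new example 48 included later in this diff (FP32 operands targeting SM90) and omits the host-side setup; refer to that file for the full, runnable program:

```cpp
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

using namespace cute;

// Mainloop assembled by the CollectiveBuilder from high-level choices:
// architecture, operand types/layouts/alignments, tile and cluster shapes.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    float, cutlass::layout::RowMajor,    4,   // A: element, layout, alignment
    float, cutlass::layout::ColumnMajor, 4,   // B: element, layout, alignment
    float,                                    // accumulator element
    Shape<_128,_128,_32>,                     // threadblock tile shape
    Shape<_1,_2,_1>,                          // threadblock cluster shape
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    cutlass::gemm::TagToStrideC_t<cutlass::layout::ColumnMajor>,
    cutlass::gemm::TagToStrideC_t<cutlass::layout::ColumnMajor>,
    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>>;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int>,        // ProblemShape (M, N, K)
    CollectiveMainloop,
    CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```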

## New architecture, compiler, and CUDA Toolkit requirements

Minimum requirements:

- Architecture: Volta
- Compiler: Must support at least C++17
- CUDA Toolkit version: 11.4

CUTLASS 3.0 *removes support* for the following:

- Maxwell and Pascal GPU architectures
- Ubuntu 16.04
- CUDA 10.2
- C++ language versions less than 17.

**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**

# Performance

<p align="center"><img src=/media/images/cutlass-2.8-gemm-performance.png></p>
<p align="center"><img src=media/images/cutlass-3.0-gemm-peak-performance.png></p>

CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit performance comparable to cuBLAS for scalar GEMM
they exhibit peak performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
for large matrix dimensions on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/),
an [NVIDIA A2](https://www.nvidia.com/en-us/data-center/products/a2/),
an [NVIDIA TitanV](https://www.nvidia.com/en-us/titan/titan-v/),
and an [NVIDIA GeForce 2080 Ti](https://www.nvidia.com/en-us/geforce/graphics-cards/rtx-2080-ti/)
compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's
for large matrix dimensions on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture),
an [NVIDIA L40](https://www.nvidia.com/en-us/data-center/l40/) (NVIDIA Ada architecture),
an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere architecture),
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).

<p align="center"><img src=/media/images/cutlass-2.9-implicit-gemm-performance.png></p>
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>

When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
@@ -83,39 +98,48 @@ as shown in the above figure. Tensor Core operations are still implemented usin

# Compatibility

CUTLASS requires a C++11 host compiler and performs best when built with the [**CUDA 11.8 Toolkit**](https://developer.nvidia.com/cuda-toolkit).

It is also compatible with CUDA 11.x.
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8.

## Operating Systems
We have tested the following environments.

|**Operating System** | **Compiler** |
|-----------------|----------|
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| | Microsoft Visual Studio 2019|
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 22.04 | GCC 11.2.0 |

Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.
Note: We plan to add Windows (MSVC) & Clang compiler support soon.

## Hardware
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.

|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA A10 |8.6|11.1|11.1|
|NVIDIA GeForce 3090|8.6|11.1|11.1|
|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8; Mixed precision: 12.0|

|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
|---|---|---|
|NVIDIA V100 Tensor Core GPU |7.0|11.4|
|NVIDIA TitanV |7.0|11.4|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4|
|NVIDIA T4 |7.5|11.4|
|NVIDIA A100 Tensor Core GPU |8.0|11.4|
|NVIDIA A10 |8.6|11.4|
|NVIDIA GeForce RTX 3090 |8.6|11.4|
|NVIDIA GeForce RTX 4090 |8.9|11.8|
|NVIDIA L40 |8.9|11.8|
|NVIDIA H100 Tensor Core GPU |9.0|11.8|

## Target Architecture

In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).

The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error.

```
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
```

Please refer to the [functionality documentation](media/docs/functionality.md) for details on which kernels require which target architectures.

# Documentation

@@ -125,7 +149,9 @@ CUTLASS is described in the following documents and the accompanying
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts
- [CUTLASS 3.x Design](media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](media/docs/terminology.md) - describes terms used in the code
@@ -161,7 +187,8 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```

Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, and 9.0.
To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.

@@ -224,6 +251,23 @@ include/ # client applications should target this directory
    transform/               # code specialized for layout, type, and domain transformations

    *                        # core vocabulary types, containers, and basic numeric operations

  cute/                      # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy

    algorithm/               # Definitions of core operations such as copy, gemm, and operations on cute::tuples

    arch/                    # Bare bones PTX wrapper structs for copy and math instructions

    atom/                    # Meta-information either link to or built from arch/ operators

      mma_atom.hpp           # cute::Mma_Atom and cute::TiledMma

      copy_atom.hpp          # cute::Copy_Atom and cute::TiledCopy

      *sm*.hpp               # Arch specific meta-information for copy and math operations

    *                        # Core library types such as Shape, Stride, Layout, Tensor, and associated operations

```

### CUTLASS SDK Examples
@@ -269,7 +313,7 @@ By default, only one tile size is instantiated for each data type, math instruct
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
Beware, this results in *thousands* of kernels and long build times.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j16
```

@@ -40,7 +40,7 @@ elseif(NOT TARGET cublas)
|
||||
|
||||
find_path(
|
||||
_CUBLAS_INCLUDE_DIR
|
||||
NAMES cublas.h
|
||||
NAMES cublas_v2.h
|
||||
HINTS
|
||||
${CUBLAS_INCLUDE_PATH}
|
||||
ENV CUBLAS_INCLUDE_PATH
|
||||
|
||||
@@ -45,5 +45,6 @@ target_link_libraries(
|
||||
PRIVATE
|
||||
cutlass_lib
|
||||
cutlass_tools_util_includes
|
||||
cuda
|
||||
)
|
||||
|
||||
|
||||
@@ -45,5 +45,6 @@ target_link_libraries(
|
||||
PRIVATE
|
||||
cutlass_lib
|
||||
cutlass_tools_util_includes
|
||||
cuda
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm)
|
||||
+ lightweight full reduction kernel (ApplyFinalReduction)
|
||||
+ GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion)
|
||||
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -77,7 +77,7 @@ template <
|
||||
typename ElementLayernormCompute_,
|
||||
typename ElementOutput,
|
||||
typename ThreadblockShape_,
|
||||
bool IsShiftedVariance_ = false
|
||||
bool IsShiftedVariance_ = false
|
||||
>
|
||||
class ApplyFinalReduction {
|
||||
public:
|
||||
@@ -91,7 +91,7 @@ public:
|
||||
using Layout = cutlass::layout::RowMajor;
|
||||
|
||||
using TensorVariance = TensorRef<ElementVariance, Layout>;
|
||||
using TensorMean = TensorRef<ElementMean, Layout>;
|
||||
using TensorMean = TensorRef<ElementMean, Layout>;
|
||||
|
||||
static bool const kIsShiftedVariance = IsShiftedVariance_;
|
||||
|
||||
@@ -463,7 +463,7 @@ public:
|
||||
for (int rid = 0; rid < kRowIterations; ++rid) {
|
||||
int row_step_offset = rid * kDeltaRow;
|
||||
int row_offset = thread_offset_row_base + step_offset + row_step_offset;
|
||||
bool is_load = (row_offset < extent_.row());
|
||||
bool is_load = (row_offset < extent_.row());
|
||||
shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load);
|
||||
}
|
||||
|
||||
@@ -504,9 +504,9 @@ public:
|
||||
using Minus = cutlass::minus<ElementLayernormCompute>;
|
||||
using Exp = cutlass::fast_exp_op<ElementLayernormCompute>;
|
||||
|
||||
Minus minus;
|
||||
Mul mul;
|
||||
Exp exponential;
|
||||
[[maybe_unused]] Minus minus;
|
||||
[[maybe_unused]] Mul mul;
|
||||
[[maybe_unused]] Exp exponential;
|
||||
|
||||
LayernormFragment result;
|
||||
|
||||
@@ -605,7 +605,7 @@ private:
|
||||
CUTLASS_DEVICE
|
||||
ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) {
|
||||
using ConvertShiftK = cutlass::NumericConverter<ElementLayernormCompute, ElementOutput>;
|
||||
ConvertShiftK convert_shift_k;
|
||||
ConvertShiftK convert_shift_k;
|
||||
ElementOutput shift_k_val;
|
||||
|
||||
// Computes the address to load shift_k element
|
||||
@@ -614,7 +614,7 @@ private:
|
||||
arch::global_load<ElementOutput, sizeof(ElementOutput)>(shift_k_val, (void *)curr_ptr_shift_k, is_load);
|
||||
// Converts data type to return
|
||||
ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val);
|
||||
|
||||
|
||||
return converted_shift_k_val;
|
||||
}
|
||||
|
||||
@@ -689,7 +689,7 @@ public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
|
||||
static bool const kInternalTranspose = cutlass::platform::is_same<LayoutOutput_, cutlass::layout::ColumnMajor>::value;
|
||||
static bool const kIsShiftedVariance = IsShiftedVariance_;
|
||||
|
||||
@@ -704,14 +704,14 @@ public:
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using ArchTag = cutlass::arch::Sm80;
|
||||
|
||||
// These are mandatory layouts and data types
|
||||
// These are mandatory layouts and data types
|
||||
// that are inherited from pre-defined params
|
||||
|
||||
|
||||
using LayoutSumSqr = LayoutInputScaleBias;
|
||||
using LayoutSum = LayoutInputScaleBias;
|
||||
|
||||
using ElementMean = ElementInputScaleBias;
|
||||
using ElementVariance = ElementInputScaleBias;
|
||||
using ElementVariance = ElementInputScaleBias;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -720,7 +720,7 @@ public:
|
||||
using LayoutInputA1 = LayoutOutput_;
|
||||
using LayoutInputB1 = LayoutOutput_;
|
||||
using LayoutOutputC0 = LayoutOutput_;
|
||||
using LayoutOutputC1 = LayoutOutput_;
|
||||
using LayoutOutputC1 = LayoutOutput_;
|
||||
|
||||
using ElementInputA0 = ElementInputA0_;
|
||||
using ElementInputB0 = ElementInputB0_;
|
||||
@@ -747,7 +747,7 @@ public:
|
||||
static int const kStages1 = Stages1;
|
||||
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
|
||||
|
||||
@@ -180,9 +180,9 @@ public:
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments():
|
||||
Arguments():
|
||||
problem_count(0),
|
||||
threadblock_count(0),
|
||||
threadblock_count(0),
|
||||
ptr_Q(nullptr),
|
||||
ptr_K(nullptr),
|
||||
ptr_P(nullptr),
|
||||
@@ -201,7 +201,7 @@ public:
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
Arguments(
|
||||
GemmCoord *problem_sizes0,
|
||||
GemmCoord *problem_sizes1,
|
||||
int problem_count,
|
||||
@@ -219,7 +219,7 @@ public:
|
||||
typename LayoutO::Stride::LongIndex *ldo,
|
||||
bool causal,
|
||||
GemmCoord *host_problem_sizes=nullptr
|
||||
):
|
||||
):
|
||||
problem_sizes0(problem_sizes0),
|
||||
problem_sizes1(problem_sizes1),
|
||||
problem_count(problem_count),
|
||||
@@ -311,7 +311,7 @@ public:
|
||||
ldv(args.ldv),
|
||||
ldo(args.ldo),
|
||||
causal(args.causal)
|
||||
{
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
@@ -464,7 +464,7 @@ public:
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
auto& si = shared_storage.after_mm0.si;
|
||||
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
|
||||
auto& mi = shared_storage.mi;
|
||||
|
||||
ProblemVisitor problem_visitor(
|
||||
|
||||
@@ -481,7 +481,7 @@ struct AttentionKernel {
|
||||
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
auto& si = shared_storage.after_mm0.si;
|
||||
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
|
||||
auto& mi = shared_storage.mi;
|
||||
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
|
||||
@@ -384,7 +384,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
// but not supported as it worsens perf: older gpus < sm80 don't
|
||||
// support async transfers and have to waste registers
|
||||
CUTLASS_DEVICE
|
||||
bool set_prologue_done(bool value) {}
|
||||
void set_prologue_done(bool value) {}
|
||||
CUTLASS_DEVICE
|
||||
static void prologue(
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
@@ -695,7 +695,7 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool set_prologue_done(bool value) {
|
||||
void set_prologue_done(bool value) {
|
||||
prologue_done_ = value;
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
"classic data-parallel" and "Split-K" decompositions.
|
||||
|
||||
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
|
||||
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
|
||||
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
|
||||
|
||||
Requires NVIDIA Ampere or newer device (SM80+).
|
||||
|
||||
|
||||
examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
@@ -0,0 +1,463 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
    \brief Simple Hopper GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture

    This example demonstrates a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0
    APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:

    1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
    which are more efficient than the Ampere tensor core instructions.

    2. NVIDIA Hopper architecture includes a new Tensor Memory Accelerator (TMA) unit to transfer large
    blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
    copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and
    convert them implicitly to TF32.

    3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).

    Examples:

      $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = float; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = float; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TilesShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(5120), n(4096), k(4096),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(1000)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "48_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "48_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {

  stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
  stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
  stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
  stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));

  block_A.reset(options.m * options.k);
  block_B.reset(options.k * options.n);
  block_C.reset(options.m * options.n);
  block_D.reset(options.m * options.n);
  block_ref_D.reset(options.m * options.n);

  initialize_block(block_A, seed + 2023);
  initialize_block(block_B, seed + 2022);
  initialize_block(block_C, seed + 2021);
}
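// With a batch count of 1 (the Int<1>{} third mode above), these packed strides describe plain
// row-/column-major matrices; example 49 later in this change makes the same make_cute_packed_stride
// calls with a runtime batch extent L to obtain batched strides instead.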

/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
  typename Gemm::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {options.m, options.n, options.k},
    block_A.get(),
    stride_A,
    block_B.get(),
    stride_B,
    {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}
  };

  return arguments;
}
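// The nested brace-initializer above is the epilogue argument pack: pointer and stride for C,
// pointer and stride for D, followed by the {alpha, beta} epilogue scalars.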

bool verify(const Options &options) {
  cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
  cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k}));
  cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
  cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));

  //
  // Compute reference output
  //

  // Create instantiation for device reference gemm kernel
  DeviceGemmReference gemm_reference;

  // Launch device reference gemm kernel
  gemm_reference(
    {options.m, options.n, options.k},
    ElementAccumulator(options.alpha),
    ref_A,
    ref_B,
    ElementAccumulator(options.beta),
    ref_C,
    ref_D);

  // Wait for kernel to finish
  CUDA_CHECK(cudaDeviceSynchronize());

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());

  return passed;
}

/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
  initialize(options);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
  auto arguments = args_from_options(options);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check if the problem size is supported or not
  CUTLASS_CHECK(gemm.can_implement(arguments));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

  // Correctness / Warmup iteration
  CUTLASS_CHECK(gemm.run());

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  Result result;
  result.passed = verify(options);

  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

  if (!result.passed) {
    exit(-1);
  }

  // Run profiling loop
  if (options.iterations > 0)
  {
    GpuTimer timer;
    timer.start();
    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(gemm.run());
    }
    timer.stop();

    // Compute average runtime and GFLOPs.
    float elapsed_ms = timer.elapsed_millis();
    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
    std::cout << "  GFLOPS: " << result.gflops << std::endl;
  }

  return 0;
}

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  // CUTLASS must be compiled with the CUDA 12.0 Toolkit to run this example,
  // and the device must have compute capability at least 90.
  if (__CUDACC_VER_MAJOR__ < 12) {
    std::cerr << "This example requires CUDA 12 or newer.\n";
    // Returning zero so this test passes on older Toolkits, even though its actions are a no-op.
    return 0;
  }

  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));

  if (props.major < 9) {
    std::cerr
      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
      << "later (compute capability 90 or greater).\n";
    return 0;
  }

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  //
  // Evaluate CUTLASS kernels
  //

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  run<Gemm>(options);
#endif

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
35 examples/48_hopper_warp_specialized_gemm/CMakeLists.txt (new file)
@@ -0,0 +1,35 @@
cutlass_example_add_executable(
  48_hopper_warp_specialized_gemm
  48_hopper_warp_specialized_gemm.cu
)
522 examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu (new file)
@@ -0,0 +1,522 @@
/*! \file
    \brief Hopper GEMM example leveraging collective operation builders.

    This example showcases the use of CUTLASS's CollectiveBuilder to easily construct performant kernels
    targeting the NVIDIA Hopper architecture.

    Background and motivation
    -------------------------
    CUTLASS kernels are highly parameterizable via template parameters. To ease the selection of template
    parameters, CUTLASS 2 leveraged DefaultGemmConfigurations. Given a small set of parameters, such as
    the data types of operands and the compute capability of the GPU, DefaultGemmConfigurations defined sensible
    defaults for the many other parameters to the kernel (e.g., warp shape, stage count).

    However, DefaultGemmConfigurations leave multiple opportunities for improvement, which are addressed
    in CUTLASS 3:
    (1) DefaultGemmConfigurations do not allow one to use a more-performant set of parameters without
        specifying every parameter. For example, the DefaultGemmConfigurations for GEMMs targeting
        Ampere specify that three pipeline stages should be used regardless of the sizes of operands.
        If one wished to increase this value, one would also need to specify all other template parameters.
        This leaves a gap between a high-level ease-of-use interface and a lower-level detailed interface.
    (2) A new DefaultGemmConfiguration was required for each combination of operand types, GPU architecture,
        and operation type (e.g., Tensor Core or SIMT). This led to increased code size to cover each unique
        configuration and a lack of extensibility from one DefaultGemmConfiguration to another.

    Alongside these opportunities for improvement, the Hopper architecture offers new features that increase
    the number of valid configurations of a kernel. In addition to the many template parameters already available
    in CUTLASS 2 kernels, CUTLASS 3 kernels targeting Hopper also have various scheduling modes to select from that control:
    (1) how data is to be loaded (e.g., using the Hopper TMA feature or Ampere cp.async)
    (2) how work is to be divided among warps in a thread block (e.g., whether to use "warp specialization")
    (3) whether persistent thread blocks should be used
    This increased configuration space further motivates rethinking DefaultGemmConfigurations.

    Introduction to the CollectiveBuilder
    -------------------------------------
    CUTLASS 3 introduces the CollectiveBuilder to further ease the process of selecting template parameters
    for kernels targeting Hopper. Similar to the DefaultGemmConfigurations used in CUTLASS 2, the CollectiveBuilder
    takes in a small set of template parameters (e.g., the data types of operands A and B). It then automatically
    determines the data loading strategy to use depending on whether the Hopper TMA feature can be used with the provided
    parameters. If one does not indicate a particular scheduling policy or stage count to use (by using `Auto` template
    parameters), the CollectiveBuilder will also automatically select these.

    Unlike DefaultGemmConfigurations, a partial specialization of the CollectiveBuilder is not needed for many
    configurations of operand types. Instead, the CollectiveBuilder "builds" a configuration based on generic
    properties of the specified operands, layouts, and other parameters. For example, when the stage count
    is set to `Auto`, the CollectiveBuilder may automatically calculate the maximum number of stages that
    will fit in shared memory given the types of operands and the thread block shape, rather than simply using
    a single default value.

    Note that one does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide
    every template parameter to the gemm::collective::CollectiveMma. Specifying every template parameter in this
    manner remains the primary API for using CUTLASS 3 kernels. The CollectiveBuilder is simply meant to be
    a convenience interface.

    Note also that, while the selections made by the CollectiveBuilder attempt to maximize performance, this is not
    a guarantee. Furthermore, the behavior of the CollectiveBuilder when `Auto` parameters are provided is subject
    to change in future CUTLASS releases -- do not rely on `Auto` if you require a specific scheduling policy and/or
    stage count to be used.

    Details of this example
    -----------------------
    This example walks through the use of the CollectiveBuilder with various schedules and stage counts specified.
    This example also illustrates how CUTLASS 3 GEMMs targeting Hopper automatically support batched GEMMs by simply
    extending the problem size with an additional tensor rank.

    Example usage:
      $ ./examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder \
            --m=2048 --n=2048 --k=2048 --l=2
*/
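// In condensed form (types taken from the ExampleRunner defined further down in this file), the
// pattern this example builds looks like the following; the sketch below is abbreviated, and the
// full definitions appear later:
//
//   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
//       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
//       cutlass::half_t, cutlass::layout::RowMajor,    8,    // A: data type, layout, alignment
//       cutlass::half_t, cutlass::layout::ColumnMajor, 8,    // B: data type, layout, alignment
//       float,                                                // accumulator type
//       Shape<_128,_128,_64>, Shape<_2,_1,_1>,                // tile shape, cluster shape
//       cutlass::gemm::collective::StageCountAuto,            // let the builder pick the stage count
//       cutlass::gemm::collective::KernelScheduleAuto         // let the builder pick the schedule
//     >::CollectiveOp;
//
//   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
//       Shape<int,int,int,int>, CollectiveMainloop, CollectiveEpilogue>;
//   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;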
|
||||
#include <iostream>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
|
||||
int m, n, k, l;
|
||||
float alpha, beta;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
alpha(1.f), beta(0.f)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "49_hopper_gemm_schedules_with_collective_builder\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " performant kernels targetting NVIDIA's Hopper architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
// operation builders by specializing the GEMM only on the kernel schedule it will use and the
// number of pipeline stages.
//
// For either option, one can use a special `Auto` type that tells the CollectiveBuilder
// to select an appropriate value on its own. The CollectiveBuilder will attempt to select
// values that will result in the most-performant kernel, but this is not a guarantee. Furthermore,
// the behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases
// -- do not rely on `Auto` if you require a specific scheduling policy.
template <
  // Type of kernel schedule to generate
  class KernelScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
  // Number of pipeline stages to use
  class StageCountType = cutlass::gemm::collective::StageCountAuto
>
struct ExampleRunner {

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::ColumnMajor;

  static constexpr int kAlignmentA = 8;
  static constexpr int kAlignmentB = 8;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, LayoutA, kAlignmentA,
    cutlass::half_t, LayoutB, kAlignmentB,
    float,
    Shape<_128,_128,_64>, Shape<_2,_1,_1>,
    StageCountType,
    KernelScheduleType
  >::CollectiveOp;

  using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    cutlass::gemm::TagToStrideC_t<LayoutC>,
    cutlass::gemm::TagToStrideC_t<LayoutD>,
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>>;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,
    CollectiveMainloop,
    CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
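  // For reference, pinning the mainloop to a specific schedule and stage count (as the runners
  // declared in main() below do through these template parameters) amounts to something like:
  //
  //   using CollectiveMainloopTmaWs5 = typename cutlass::gemm::collective::CollectiveBuilder<
  //       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  //       cutlass::half_t, LayoutA, kAlignmentA,
  //       cutlass::half_t, LayoutB, kAlignmentB,
  //       float,
  //       Shape<_128,_128,_64>, Shape<_2,_1,_1>,
  //       _5,                                        // five pipeline stages instead of StageCountAuto
  //       cutlass::gemm::KernelTmaWarpSpecialized    // warp-specialized TMA schedule instead of KernelScheduleAuto
  //     >::CollectiveOp;
  //
  // (Illustrative alias name only; this example always goes through the template parameters above.)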
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideA>());
|
||||
using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B<StrideB>());
|
||||
using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideC>());
|
||||
using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideD>());
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShapeType& problem_size, float alpha, float beta) {
|
||||
auto [M, N, K, L] = problem_size;
|
||||
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));
|
||||
|
||||
cutlass::reference::device::GemmComplex(
|
||||
{M, N, K},
|
||||
typename Gemm::EpilogueOutputOp::ElementCompute(alpha),
|
||||
ref_A,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ref_B,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
typename Gemm::EpilogueOutputOp::ElementCompute(beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
typename Gemm::EpilogueOutputOp::ElementAccumulator(0.f),
|
||||
L, // batch_count
|
||||
M * K, // batch_stride_A
|
||||
K * N, // batch_stride_B
|
||||
M * N, // batch_stride_C
|
||||
M * N // batch_stride_D
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(K * N * L);
|
||||
block_C.reset(M * N * L);
|
||||
block_D.reset(M * N * L);
|
||||
block_ref_D.reset(M * N * L);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
block_A.get(),
|
||||
stride_A,
|
||||
block_B.get(),
|
||||
stride_B,
|
||||
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
|
||||
hw_info
|
||||
};
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, options.alpha, options.beta);
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, bool passed) {
|
||||
std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
  bool passed;

  // This first example constructs a GEMM using the default schedule and stage count provided by
  // the CollectiveBuilder: the scheduling policy that is expected to be most performant is chosen,
  // and the maximum number of stages that can fit in shared memory is selected.
  //
  // This example is equivalent to declaring
  //   ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, cutlass::gemm::collective::StageCountAuto>
  // Each of the `Auto` types indicates that the CollectiveBuilder should determine the scheduling policy and
  // stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change
  // -- do not rely on `Auto` if you require a specific scheduling policy.
  ExampleRunner<> auto_schedule_auto_stage_runner;
  passed = auto_schedule_auto_stage_runner.run(options, hw_info);
  print_result("Automatically-selected schedule and stage count", passed);

  // One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto
  // with the number of stages to use (5 in this case).
  ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, _5> auto_schedule_5_stage_runner;
  passed = auto_schedule_5_stage_runner.run(options, hw_info);
  print_result("Automatically-selected schedule with 5 stages", passed);

  // One can also override the scheduling policy to use. In this case, use the KernelTma scheduling
  // policy, which specifies that the Hopper TMA feature should be used.
  ExampleRunner<cutlass::gemm::KernelTma> tma_schedule_auto_stage_runner;
  passed = tma_schedule_auto_stage_runner.run(options, hw_info);
  print_result("TMA schedule with automatically-selected stage count", passed);

  // Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized
  // scheduling policy.
  //
  // Note that, as of the CUTLASS 3.0 release, this is the default scheduling policy
  // used by the CollectiveBuilder, so this declaration is equivalent to ExampleRunner<> and
  // ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto>. However, this default is subject to
  // change in future releases -- do not rely on `Auto` if you require a specific scheduling policy.
  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized> ws_schedule_auto_stage_runner;
  passed = ws_schedule_auto_stage_runner.run(options, hw_info);
  print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed);

  // Finally, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized
  // scheduling policy, leveraging persistent thread blocks.
  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecializedPersistent> ws_persistent_schedule_auto_stage_runner;
  passed = ws_persistent_schedule_auto_stage_runner.run(options, hw_info);
  print_result("Persistent warp-specialized TMA schedule with automatically-selected stage count", passed);

#endif

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
35 examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt (new file)
@@ -0,0 +1,35 @@
cutlass_example_add_executable(
  49_hopper_gemm_schedules_with_collective_builder
  49_hopper_gemm_schedules_with_collective_builder.cu
)
529 examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu (new file)
@@ -0,0 +1,529 @@
/*! \file
    \brief Hopper GEMM example that creates a GEMM kernel with custom Collectives

    The following example shows how to assemble a custom GEMM kernel that spells out the Collectives
    directly instead of using a builder and, in the process, instantiates a more efficient Epilogue
    (from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue.

    The GemmUniversal API takes 3 main template arguments:
      (1) the problem shape / extents
      (2) the collective mainloop type
      (3) the collective epilogue type

    While the collective mainloop can be stamped out using a CollectiveBuilder interface, it is
    possible to build a custom collective mainloop directly as well. Furthermore, since epilogues
    do not yet have a builder interface, this example shows how to instantiate a more-efficient
    epilogue alongside the collective mainloop.

    Note: there are several ways to implement the GEMM epilogue in Hopper - each with its own set
    of trade-offs. It is therefore recommended that users look at the options available under
    cutlass/epilogue/collective and evaluate them for their particular scenario.

    Please refer to examples 48 and 49 to learn more about kernel schedules, and to the CuTe
    examples in `test/unit/cute` to familiarize yourself with the basics of CuTe.

    Examples:

      $ ./examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle
*/
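// For contrast: example 49 relies on the ready-made default epilogue,
//
//   using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
//       cutlass::gemm::TagToStrideC_t<LayoutC>,
//       cutlass::gemm::TagToStrideC_t<LayoutD>,
//       cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>>;
//
// whereas this example assembles cutlass::epilogue::collective::Epilogue in main() below with an
// explicit shared-memory layout and tiled copies so that the final global stores are vectorized.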
#include <iostream>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"

using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
|
||||
int m, n, k, l;
|
||||
int alpha, beta;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
alpha(1), beta(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "50_hopper_gemm_with_vectorized_epilogue\n\n"
|
||||
<< "Hopper GEMM Example with Epilogue Swizzle.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --alpha=<s32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<s32> Epilogue scalar beta\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
// Wrapper to run and verify a GEMM.
|
||||
template <
|
||||
class Gemm
|
||||
>
|
||||
struct ExampleRunner {
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
using LayoutA = typename Gemm::LayoutA;
|
||||
using LayoutB = typename Gemm::LayoutB;
|
||||
using LayoutC = typename Gemm::LayoutC;
|
||||
using LayoutD = typename Gemm::LayoutD;
|
||||
|
||||
using ElementA = typename Gemm::ElementA;
|
||||
using ElementB = typename Gemm::ElementB;
|
||||
using ElementAcc = typename Gemm::ElementAccumulator;
|
||||
|
||||
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
|
||||
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
|
||||
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
|
||||
|
||||
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShapeType& problem_size, int32_t alpha, int32_t beta) {
|
||||
auto [M, N, K, L] = problem_size;
|
||||
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
|
||||
|
||||
cutlass::reference::device::GemmComplex(
|
||||
{M, N, K},
|
||||
ElementCompute(alpha),
|
||||
ref_A,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ref_B,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementCompute(beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
ElementAccumulator(0),
|
||||
L, // batch_count
|
||||
M * K, // batch_stride_A
|
||||
K * N, // batch_stride_B
|
||||
M * N, // batch_stride_C
|
||||
M * N // batch_stride_D
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(K * N * L);
|
||||
block_C.reset(M * N * L);
|
||||
block_D.reset(M * N * L);
|
||||
block_ref_D.reset(M * N * L);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::GemmKernel::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
block_A.get(),
|
||||
stride_A,
|
||||
block_B.get(),
|
||||
stride_B,
|
||||
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
|
||||
hw_info
|
||||
};
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, options.alpha, options.beta);
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
bool passed;
|
||||
|
||||
// Problem configuration
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementAcc = int32_t;
|
||||
using ElementOutput = int8_t;
|
||||
|
||||
// Note : Only TN WGMMA Gemm is supported currently in 3.0
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using LayoutD = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Tiling configuration selection
|
||||
using TileShape = Shape<_128,_64,_128>;
|
||||
|
||||
// Choosing a thread block cluster larger than 1 allows us to multicast data across thread blocks
|
||||
using ClusterShape = Shape<_1,_2,_1>;
|
||||
|
||||
//
|
||||
// Assembling the CollectiveMainloop type
|
||||
//
|
||||
|
||||
// Pipeline depth to be used, i.e., the number of A and B buffers in shared memory
|
||||
constexpr int PipelineStages = 8;
|
||||
|
||||
// Let's choose a warp-specialized mainloop implementation that uses TMA
|
||||
// Note: this requires/assumes the tensors to be 16B aligned
|
||||
using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>;
|
||||
|
||||
// TN => K Major for both A & B
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = cute::GMMA::Major::K;
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = cute::GMMA::Major::K;
|
||||
|
||||
// We use the SS op selector as both A, B operands are read directly from SMEM (for TN WGMMA)
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementA, ElementB, ElementAcc, TileShape, GmmaMajorA, GmmaMajorB>()));
|
||||
|
||||
// A loads can be optimized with multicast if cluster-n > 1
|
||||
using GmemTiledCopyA = std::conditional< cute::size(shape<1>(ClusterShape{})) == 1,
|
||||
cute::SM90_TMA_LOAD,
|
||||
cute::SM90_TMA_LOAD_MULTICAST>::type;
|
||||
|
||||
// B loads can be optimized with multicast if cluster-m > 1
|
||||
using GmemTiledCopyB = std::conditional< cute::size(shape<0>(ClusterShape{})) == 1,
|
||||
cute::SM90_TMA_LOAD,
|
||||
cute::SM90_TMA_LOAD_MULTICAST>::type;
|
||||
|
||||
using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{}))
|
||||
>());
|
||||
|
||||
using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{}))
|
||||
>());
|
||||
|
||||
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape,
|
||||
ElementA,
|
||||
cutlass::gemm::TagToStrideA_t<LayoutA>,
|
||||
ElementB,
|
||||
cutlass::gemm::TagToStrideB_t<LayoutB>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
void, // Does not need a SmemCopyAtom, since A is read directly from SMEM
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void, // Does not need a SmemCopyAtom, since B is read directly from SMEM
|
||||
cute::identity
|
||||
>;
|
||||
|
||||
//
|
||||
// Assembling the Collective Epilogue Type
|
||||
//
|
||||
|
||||
// Break the 128 along TILE_M into chunks of 32, to get a 128B leading dimension
|
||||
using PreSwizzleLayout = Layout< Shape< Shape <_32,_4 >,_64>,
|
||||
Stride<Stride< _1,_2048>,_32>>;
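  // (The 128B figure follows from the accumulator type: 32 int32_t elements x 4 bytes = 128 bytes.
  //  The Stride<Stride<_1,_2048>,_32> above places the four chunks of 32 rows 2048 = 32 x 64
  //  elements apart, so the 128x64 tile is stored as four contiguous 32x64 column-major panels.)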
|
||||
|
||||
// 128 threads loading 16 elements each (to get vectorized global stores)
|
||||
using TileShapeS2R = Shape<_128,_16>;
|
||||
|
||||
// Layout to ensure bank-conflict free loads & stores
|
||||
using SmemLayout = ComposedLayout<
|
||||
Swizzle<3,4,3>,
|
||||
smem_ptr_flag_bits<sizeof_bits<ElementAcc>::value>,
|
||||
PreSwizzleLayout>;
|
||||
|
||||
// Tiled copy from Smem to Registers
|
||||
// Note: CuTe will vectorize this copy if the tiling and swizzling above are right
|
||||
using TiledCopyS2R = TiledCopy<
|
||||
Copy_Atom<DefaultCopy, ElementAcc>,
|
||||
Layout< Shape<_128,_16>,
|
||||
Stride<_16,_1>>,
|
||||
TileShapeS2R>;
|
||||
|
||||
using Epilogue = cutlass::epilogue::collective::Epilogue<
|
||||
cutlass::gemm::TagToStrideC_t<LayoutC>,
|
||||
cutlass::gemm::TagToStrideC_t<LayoutD>,
|
||||
cutlass::epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
|
||||
SmemLayout,
|
||||
Copy_Atom<DefaultCopy, ElementAcc>,
|
||||
TiledCopyS2R,
|
||||
Copy_Atom<DefaultCopy, ElementOutput>>;
|
||||
|
||||
//
|
||||
// Assembling the GemmKernel
|
||||
//
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
Epilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
ExampleRunner<Gemm> runner;
|
||||
|
||||
passed = runner.run(options, hw_info);
|
||||
|
||||
std::cout << "WGMMA GEMM with Epilogue Swizzle : " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
35 examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt (new file)
@@ -0,0 +1,35 @@
cutlass_example_add_executable(
  50_hopper_gemm_with_epilogue_swizzle
  50_hopper_gemm_with_epilogue_swizzle.cu
)
examples/CMakeLists.txt
@@ -54,12 +54,14 @@ function(cutlass_example_add_executable NAME)
    CUTLASS
    cutlass_tools_util_includes
    $<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
    cuda
    )

  target_include_directories(
    ${NAME}
    PRIVATE
    ${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR}
    ${CUTLASS_EXAMPLES_UTILS_DIR}
    )

  install(
@@ -118,6 +120,7 @@ foreach(EXAMPLE
  36_gather_scatter_fusion
  37_gemm_layernorm_gemm_fusion
  38_syr2k_grouped
  cute
  39_gemm_permute
  41_fused_multi_head_attention
  42_ampere_tensorop_group_conv
@@ -125,6 +128,9 @@ foreach(EXAMPLE
  45_dual_gemm
  46_depthwise_simt_conv2dfprop
  47_ampere_gemm_universal_streamk
  48_hopper_warp_specialized_gemm
  49_hopper_gemm_schedules_with_collective_builder
  50_hopper_gemm_with_epilogue_swizzle
)

  add_subdirectory(${EXAMPLE})
30  examples/cute/CMakeLists.txt  (new file)
@ -0,0 +1,30 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (remainder of the standard BSD-3-Clause header as above)

add_subdirectory(tutorial)
34  examples/cute/tutorial/CMakeLists.txt  (new file)
@ -0,0 +1,34 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# (remainder of the standard BSD-3-Clause header as above)

cutlass_example_add_executable(
  sgemm_nt_1
  sgemm_nt_1.cu
)
426  examples/cute/tutorial/sgemm_nt_1.cu  (new file)
@ -0,0 +1,426 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
#  include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"

template <class MShape, class NShape, class KShape,
          class TA, class AStride, class ABlockLayout, class AThreadLayout,
          class TB, class BStride, class BBlockLayout, class BThreadLayout,
          class TC, class CStride, class CBlockLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
            TC      * C, CStride dC, CBlockLayout       , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;
  using X = Underscore;

  // Preconditions
  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);

  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));

  //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC));      // BLK_M
  //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC));      // BLK_N
  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));        // BLK_K

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ABlockLayout>];
  __shared__ TB smemB[cosize_v<BBlockLayout>];
  auto sA = make_tensor(make_smem_ptr(smemA), blockA);               // (BLK_M,BLK_K)
  auto sB = make_tensor(make_smem_ptr(smemB), blockB);               // (BLK_N,BLK_K)

  // Represent the full tensors
  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA);      // (M,K)
  auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB);      // (N,K)
  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC);      // (M,N)

  // Get the appropriate blocks for this thread block --
  // potential for thread block locality
  auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K)
  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _);            // (m,n,k)

  auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
  //   Default is a raked partition, but can be changed with Step<X,Y> parameter

  auto tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
  auto tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

  auto tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
  auto tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

  //
  // Define C accumulators and A/B partitioning
  //

  // TUTORIAL: Example of partitioning via projections of tC

  // Partition sA (M,K) by the rows of tC
  auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
  // Partition sB (N,K) by the cols of tC
  auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC
  auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

  // Allocate the accumulators -- same size as the projected data
  auto tCrC = make_fragment_like(tCgC);                              // (THR_M,THR_N)

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("mA\n");
    print(mA.shape()); print("\n"); print(mA.stride());
    print("\n\ngA\n");
    print(gA.shape()); print("\n"); print(gA.stride());
    print("\n\ntAgA\n");
    print(tAgA.shape()); print("\n"); print(tAgA.stride());
    print("\n\nsA\n");
    print(sA.shape()); print("\n"); print(sA.stride());
    print("\n\ntAsA\n");
    print(tAsA.shape()); print("\n"); print(tAsA.stride());
    print("\n\n");
  }
#endif

#if 0
  if(thread0()) {
    print("mB\n");
    print(mB.shape()); print("\n"); print(mB.stride());
    print("\n\ngB\n");
    print(gB.shape()); print("\n"); print(gB.stride());
    print("\n\ntBgB\n");
    print(tBgB.shape()); print("\n"); print(tBgB.stride());
    print("\n\nsB\n");
    print(sB.shape()); print("\n"); print(sB.stride());
    print("\n\ntBsB\n");
    print(tBsB.shape()); print("\n"); print(tBsB.stride());
    print("\n\n");
  }
#endif

#if 0
  if(thread0()) {
    print("mC\n");
    print(mC.shape()); print("\n"); print(mC.stride());
    print("\n\ngC\n");
    print(gC.shape()); print("\n"); print(gC.stride());
    print("\n\ntCsA\n");
    print(tCsA.shape()); print("\n"); print(tCsA.stride());
    print("\n\ntCsB\n");
    print(tCsB.shape()); print("\n"); print(tCsB.stride());
    print("\n\ntCgC\n");
    print(tCgC.shape()); print("\n"); print(tCgC.stride());
    print("\n\ntCrC\n");
    print(tCrC.shape()); print("\n"); print(tCrC.stride());
    print("\n\n");
  }
#endif

#if 1

  // TUTORIAL: Example of a very simple compute loop
  //   Data is read from global to shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared memory directly via the tC partitioning

  auto k_max = size<2>(tAgA);

  for (int k = 0; k < k_max; ++k)
  {
    // Copy gmem to smem
    copy(tAgA(_,_,k), tAsA);
    copy(tBgB(_,_,k), tBsB);

    // In case copy uses cp.async, make sure that the cp.async
    // instructions are ordered with respect to other cp.async
    // instructions (fence), then wait on all the outstanding copy
    // operations (wait<0>()).  __syncthreads() alone does not do
    // this.
    //
    // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all.
    // This is equivalent to cp.async.commit_group followed by
    // cp.async_wait_group 0.  This should make the first
    // cp_async_fence() (which also issues cp.async.commit_group)
    // redundant.  The tutorial works as-is, so we'll leave the
    // redundant fence in for now and study its removal later.
    cp_async_fence();
    cp_async_wait<0>();

    __syncthreads();

    // Compute gemm on smem
    gemm(tCsA, tCsB, tCrC);

    __syncthreads();
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);
}

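// A worked example of the shapes above, using the static sizes chosen in the host-side
// launcher below (bM = bN = 128, bK = 8, tA = tB = 32x8 threads, tC = 16x16 threads);
// the arithmetic here is derived from those values rather than taken from the original file.
//
//   sA : (128, 8) floats in smem, sB : (128, 8) floats      -> 2 * 128*8*4 B = 8 KB of smem
//   gA : (128, 8, K/8) view of global memory per thread block
//   tAgA/tAsA and tBgB/tBsB : (128/32, 8/8) = (4, 1) elements copied per thread per k-tile
//   tCsA : (128/16, 8) = (8, 8), tCsB : (128/16, 8) = (8, 8)
//   tCgC, tCrC : (128/16, 128/16) = (8, 8)                   -> 64 accumulators per thread
//   dimBlock = size(tC) = 256 threads per thread block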
template <typename TA, typename TB, typename TC,
          typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);

  // Define strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);
  auto dB = make_stride(Int<1>{}, ldB);
  auto dC = make_stride(Int<1>{}, ldC);

  // Define block sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};

  // Define the block layouts (static)
  auto sA = make_layout(make_shape(bM,bK));
  auto sB = make_layout(make_shape(bN,bK));
  auto sC = make_layout(make_shape(bM,bN));

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));

  dim3 dimBlock(size(tC));
  dim3 dimGrid(ceil_div(size(M), size(bM)),
               ceil_div(size(N), size(bN)));
  gemm_device
      <<< dimGrid, dimBlock, 0, stream >>>
      (M,  N,  K,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}

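// Note on the "NT" naming (added commentary, not in the original file): with the strides
// above, A is an M x K column-major matrix with leading dimension ldA and B is an N x K
// column-major matrix with leading dimension ldB, so the kernel computes
//   C(m,n) = alpha * sum_k A(m,k) * B(n,k) + beta * C(m,n),  i.e.  C = alpha * A * B^T + beta * C,
// which matches the CUBLAS_OP_N / CUBLAS_OP_T reference call used for verification below.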
#include <cstdlib>
#include <cstdio>
#include <cassert>

void test_gemm(int m, int n, int k)
{
  cute::device_init(0);

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  TI alpha = 1.0;
  TI beta  = 0.0;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
  //
  // cuBLAS
  //

  cublasHandle_t handle;
  cublasCreate(&handle);

  // Run once
  d_C = h_C;
  blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
                     m, n, k,
                     &alpha,
                     d_A.data().get(), m,
                     d_B.data().get(), n,
                     &beta,
                     d_C.data().get(), m);
  CUTE_CHECK_LAST();

  thrust::host_vector<TC> cublas_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
                       m, n, k,
                       &alpha,
                       d_A.data().get(), m,
                       d_B.data().get(), n,
                       &beta,
                       d_C.data().get(), m);
  }
  double cublas_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUBLAS_GEMM:   [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000);

#else

  std::cout << "Verification by comparison with cuBLAS is disabled, "
    "either because the CMake option CUTLASS_ENABLE_CUBLAS "
    "was explicitly set to OFF, or because CMake could not find cuBLAS. "
    "If you would like to enable verification with cuBLAS, "
    "please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, "
    "rerun CMake, and recompile this example.\n";

#endif // CUTLASS_ENABLE_CUBLAS

  //
  // CuTe
  //

  // Run once (and check)
  d_C = h_C;
  gemm(m, n, k,
       alpha,
       d_A.data().get(), m,
       d_B.data().get(), n,
       beta,
       d_C.data().get(), m);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(m, n, k,
         alpha,
         d_A.data().get(), m,
         d_B.data().get(), n,
         beta,
         d_C.data().get(), m);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:     [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
  printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100);

  auto host_matrix_to_const_column_major_cute_tensor =
    [](const auto& X, int num_rows, int num_cols, int LDX) {
      const auto shape = cute::Shape<int, int>{num_rows, num_cols};
      const auto strides = cute::Stride<int, int>{1, LDX};
      return cute::make_tensor(X.data(), cute::make_layout(shape, strides));
    };

  const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m);
  // B^T is k x n, so B is n x k.
  const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n);
  const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m);
  const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m);
  print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view);

#endif // CUTLASS_ENABLE_CUBLAS
}

int main(int argc, char** argv)
{
  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  test_gemm(m, n, k);

  return 0;
}
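When cuBLAS is unavailable, a plain host-side loop is enough to sanity-check the tutorial kernel. The helper below is not part of the example; it is a minimal sketch (hypothetical `verify_reference` function) that recomputes C = alpha * A * B^T + beta * C for the column-major layouts used above and could be called on h_A.data(), h_B.data(), h_C.data(), and cute_result.data() inside test_gemm.

#include <cmath>

// Hypothetical helper (not in the example): recompute the reference on the host and
// report whether `computed` matches within a small relative tolerance.
bool verify_reference(int m, int n, int k, float alpha, float beta,
                      float const* A,          // m x k, column-major, ld = m
                      float const* B,          // n x k, column-major, ld = n
                      float const* C_initial,  // m x n, column-major, ld = m  (h_C before the run)
                      float const* computed)   // m x n, column-major, ld = m  (cute_result)
{
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      double acc = 0.0;
      for (int p = 0; p < k; ++p) {
        acc += double(A[i + p*m]) * double(B[j + p*n]);   // A(i,p) * B(j,p) == (A * B^T)(i,j)
      }
      double ref = alpha * acc + beta * double(C_initial[i + j*m]);
      if (std::abs(double(computed[i + j*m]) - ref) > 1e-3 * (1.0 + std::abs(ref))) {
        return false;
      }
    }
  }
  return true;
}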
79  include/cute/algorithm/axpby.hpp  (new file)
@ -0,0 +1,79 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/tensor.hpp>

namespace cute
{

//
// Accept mutable temporaries
//
template <class Alpha,
          class XEngine, class XLayout,
          class Beta,
          class YEngine, class YLayout>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
      Tensor<XEngine, XLayout> const& x,
      Beta  const& beta,
      Tensor<YEngine, YLayout>     && y)
{
  return axpby(alpha, x, beta, y);
}

//
// AXPBY
//
template <class Alpha,
          class XEngine, class XLayout,
          class Beta,
          class YEngine, class YLayout>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
      Tensor<XEngine, XLayout> const& x,
      Beta  const& beta,
      Tensor<YEngine, YLayout>      & y)
{
  auto isBetaZero = (beta == Int<0>{});

  CUTE_UNROLL
  for (int i = 0; i < size(x); ++i) {
    y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
  }
}

} // end namespace cute
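A self-contained host-side sketch of `cute::axpby` on plain arrays, illustrative only; building a tensor view over a raw pointer follows the same pattern as the host views constructed in the tutorial above.

#include <cute/tensor.hpp>
#include <cute/algorithm/axpby.hpp>

// y <- 2*x, then y <- 2*x + 1*y; the values are illustrative only.
void axpby_demo() {
  using namespace cute;
  float x[4] = {1.f, 2.f, 3.f, 4.f};
  float y[4] = {10.f, 10.f, 10.f, 10.f};

  auto X = make_tensor(&x[0], make_layout(Int<4>{}));   // rank-1 view over x
  auto Y = make_tensor(&y[0], make_layout(Int<4>{}));   // rank-1 view over y

  axpby(2.0f, X, 0.0f, Y);   // beta == 0: y(i) = alpha * x(i), y's old contents are ignored
  axpby(2.0f, X, 1.0f, Y);   // beta != 0: y(i) = alpha * x(i) + beta * y(i)
}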
66  include/cute/algorithm/clear.hpp  (new file)
@ -0,0 +1,66 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/tensor.hpp>

#include <cute/algorithm/fill.hpp>

namespace cute
{

//
// Accept mutable temporaries
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE
void
clear(Tensor<Engine, Layout>&& tensor)
{
  return clear(tensor);
}

//
// Set elements to zero
//
template <class Engine, class Layout>
CUTE_HOST_DEVICE
void
clear(Tensor<Engine, Layout>& tensor)
{
  using T = typename Tensor<Engine,Layout>::value_type;

  fill(tensor, T{});
}

} // end namespace cute
262  include/cute/algorithm/copy.hpp  (new file)
@ -0,0 +1,262 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_predicate.hpp>
|
||||
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
|
||||
template <class PrdTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(PrdTensor const& pred,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy_if(pred, src, dst);
|
||||
}
|
||||
|
||||
template <class... CopyArgs,
|
||||
class PrdTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
PrdTensor const& pred,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy_if(copy_atom, pred, src, dst);
|
||||
}
|
||||
|
||||
template <class VecType,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy_vec<VecType>(src, dst);
|
||||
}
|
||||
|
||||
template <class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy(src, dst);
|
||||
}
|
||||
|
||||
template <class... CopyArgs,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> && dst)
|
||||
{
|
||||
return copy(copy_atom, src, dst);
|
||||
}
|
||||
|
||||
//
|
||||
// copy_if -- Predicated Copy
|
||||
//
|
||||
|
||||
template <class PrdTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(PrdTensor const& pred,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
auto copy_op = select_elementwise_copy(src, dst);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(src); ++i) {
|
||||
if (pred(i)) {
|
||||
copy_op.copy(src(i), dst(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// copy_if -- Predicated CopyAtom
|
||||
//
|
||||
|
||||
template <class... CopyArgs,
|
||||
class PredTensor,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
PredTensor const& pred, // (Rest...)
|
||||
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
|
||||
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
|
||||
{
|
||||
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
|
||||
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
|
||||
copy_atom.call(src, dst);
|
||||
} else { // Loop over all but the first mode
|
||||
constexpr int R = SrcLayout::rank;
|
||||
auto src_v = group_modes<1,R>(src);
|
||||
auto dst_v = group_modes<1,R>(dst);
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<1>(src_v); ++i) {
|
||||
if (pred(i)) {
|
||||
copy_atom.call(src_v(_,i), dst_v(_,i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// copy_vec -- attempt vectorized copy with VecType
|
||||
//
|
||||
|
||||
template <class VecType,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
|
||||
{
|
||||
/* @pre is_aligned<N>(src.data()) &&
|
||||
* is_aligned<N>(dst.data())
|
||||
*/
|
||||
auto src_v = recast<VecType const>(src);
|
||||
auto dst_v = recast<VecType >(dst);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy_vec -- vectorizing copy from %3db to %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(VecType)));
|
||||
print(" "); print(layout(src)); print(" => "); print(layout(src_v)); print("\n");
|
||||
print(" "); print(layout(dst)); print(" => "); print(layout(dst_v)); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy_if(TrivialPredTensor{}, src_v, dst_v);
|
||||
} else {
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy_vec -- not vectorizing, copy with %3db and %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(DstType)));
|
||||
print(" "); print(layout(src)); print("\n");
|
||||
print(" "); print(layout(dst)); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy_if(TrivialPredTensor{}, src, dst);
|
||||
}
|
||||
}
|
||||
|
||||
//
// copy -- auto-vectorizing copy
//

template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
     Tensor<DstEngine, DstLayout>      & dst)
{
  constexpr int N = decltype(max_common_vector(src, dst))::value;

#if 0
  if (thread0()) {
    print("copy -- found a max_common_vector of %d\n", N);
    print("   "); print(src.data()); print(" o "); print(layout(src)); print("\n");
    print("   "); print(dst.data()); print(" o "); print(layout(dst)); print("\n");
  }
#endif

  if constexpr (N <= 1) {
    return copy_if(TrivialPredTensor{}, src, dst);
  } else {
    constexpr int vec_bits = N * sizeof_bits<typename SrcEngine::value_type>::value;
    using VecType = uint_bit_t<cute::min(128, vec_bits)>;
    return copy_vec<VecType>(src, dst);
  }
}

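// Added note (not in the original file): vec_bits is capped at 128, so the widest access the
// auto-vectorizer will emit is a single 128-bit load/store (e.g. four packed floats), and it
// only fires when max_common_vector(src, dst) finds more than one logically contiguous element
// in both layouts; the alignment precondition is documented on copy_vec above.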
//
|
||||
// copy -- CopyAtom
|
||||
//
|
||||
|
||||
template <class... CopyArgs,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
return copy_if(copy_atom, TrivialPredTensor{}, src, dst);
|
||||
}
|
||||
|
||||
template <class... CopyArgs,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<DefaultCopy, CopyArgs...> const&,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
return copy(src, dst);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
87  include/cute/algorithm/fill.hpp  (new file)
@ -0,0 +1,87 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/tensor.hpp>
#include <cute/algorithm/prefer.hpp>

namespace cute
{

//
// Accept mutable temporaries
//
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
void
fill(Tensor<Engine, Layout>&& tensor, T const& value)
{
  return fill(tensor, value);
}

namespace detail
{

// Prefer fill(tensor.data(), value), if possible
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
auto
fill(Tensor<Engine, Layout>& tensor, T const& value, prefer<1>)
    -> decltype(fill(tensor.data(), value))
{
  fill(tensor.data(), value);
}

// Default implementation
template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
void
fill(Tensor<Engine, Layout>& tensor, T const& value, prefer<0>)
{
  CUTE_UNROLL
  for (int i = 0; i < size(tensor); ++i) {
    tensor(i) = value;
  }
}

} // end namespace detail

template <class Engine, class Layout, class T>
CUTE_HOST_DEVICE
void
fill(Tensor<Engine, Layout>& tensor, T const& value)
{
  return detail::fill(tensor, value, prefer<1>{});
}

} // end namespace cute
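The detail::fill overloads above rank candidates with prefer<1> / prefer<0>, so an engine-provided fill(tensor.data(), value) is tried before the generic element loop. Below is a generic sketch of that overload-ranking idiom using a stand-in rank_tag type; cute::prefer (defined in cute/algorithm/prefer.hpp) is assumed to play the same role, so this is an illustration of the pattern rather than CuTe's own definition.

#include <string>

// Rank tags: rank_tag<1> converts to rank_tag<0>, so higher-ranked overloads win when viable.
template <int N> struct rank_tag : rank_tag<N - 1> {};
template <>      struct rank_tag<0> {};

// Picked when T has a .clear() member; SFINAE removes it otherwise.
template <class T>
auto zero_impl(T& t, rank_tag<1>) -> decltype(t.clear(), void()) { t.clear(); }

// Fallback for everything else.
template <class T>
void zero_impl(T& t, rank_tag<0>) { t = T{}; }

template <class T>
void zero(T& t) { zero_impl(t, rank_tag<1>{}); }

// zero(s) clears a std::string via .clear(); zero(i) resets an int to 0 via the fallback.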
198  include/cute/algorithm/functional.hpp  (new file)
@ -0,0 +1,198 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
/** C++14 <functional> extensions */
|
||||
|
||||
namespace cute {
|
||||
|
||||
/**************/
|
||||
/** Identity **/
|
||||
/**************/
|
||||
|
||||
struct identity {
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&& arg) const {
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
template <class R>
|
||||
struct constant_fn {
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) operator()(T&&...) const {
|
||||
return r_;
|
||||
}
|
||||
R r_;
|
||||
};
|
||||
|
||||
/***********/
|
||||
/** Unary **/
|
||||
/***********/
|
||||
|
||||
#define CUTE_LEFT_UNARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return OP std::forward<T>(arg); \
|
||||
} \
|
||||
}
|
||||
#define CUTE_RIGHT_UNARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return std::forward<T>(arg) OP ; \
|
||||
} \
|
||||
}
|
||||
#define CUTE_NAMED_UNARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& arg) const { \
|
||||
return OP (std::forward<T>(arg)); \
|
||||
} \
|
||||
}
|
||||
|
||||
CUTE_LEFT_UNARY_OP(unary_plus, +);
|
||||
CUTE_LEFT_UNARY_OP(negate, -);
|
||||
CUTE_LEFT_UNARY_OP(bit_not, ~);
|
||||
CUTE_LEFT_UNARY_OP(logical_not, !);
|
||||
CUTE_LEFT_UNARY_OP(dereference, *);
|
||||
CUTE_LEFT_UNARY_OP(address_of, &);
|
||||
CUTE_LEFT_UNARY_OP(pre_increment, ++);
|
||||
CUTE_LEFT_UNARY_OP(pre_decrement, --);
|
||||
|
||||
CUTE_RIGHT_UNARY_OP(post_increment, ++);
|
||||
CUTE_RIGHT_UNARY_OP(post_decrement, --);
|
||||
|
||||
CUTE_NAMED_UNARY_OP(abs_fn, abs);
|
||||
CUTE_NAMED_UNARY_OP(conjugate, cute::conj);
|
||||
|
||||
#undef CUTE_LEFT_UNARY_OP
|
||||
#undef CUTE_RIGHT_UNARY_OP
|
||||
#undef CUTE_NAMED_UNARY_OP
|
||||
|
||||
/************/
|
||||
/** Binary **/
|
||||
/************/
|
||||
|
||||
#define CUTE_BINARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T, class U> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
|
||||
return std::forward<T>(lhs) OP std::forward<U>(rhs); \
|
||||
} \
|
||||
}
|
||||
#define CUTE_NAMED_BINARY_OP(NAME,OP) \
|
||||
struct NAME { \
|
||||
template <class T, class U> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
|
||||
return OP (std::forward<T>(lhs), std::forward<U>(rhs)); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
CUTE_BINARY_OP(plus, +);
|
||||
CUTE_BINARY_OP(minus, -);
|
||||
CUTE_BINARY_OP(multiplies, *);
|
||||
CUTE_BINARY_OP(divides, /);
|
||||
CUTE_BINARY_OP(modulus, %);
|
||||
|
||||
CUTE_BINARY_OP(plus_assign, +=);
|
||||
CUTE_BINARY_OP(minus_assign, -=);
|
||||
CUTE_BINARY_OP(multiplies_assign, *=);
|
||||
CUTE_BINARY_OP(divides_assign, /=);
|
||||
CUTE_BINARY_OP(modulus_assign, %=);
|
||||
|
||||
CUTE_BINARY_OP(bit_and, &);
|
||||
CUTE_BINARY_OP(bit_or, |);
|
||||
CUTE_BINARY_OP(bit_xor, ^);
|
||||
CUTE_BINARY_OP(left_shift, <<);
|
||||
CUTE_BINARY_OP(right_shift, >>);
|
||||
|
||||
CUTE_BINARY_OP(bit_and_assign, &=);
|
||||
CUTE_BINARY_OP(bit_or_assign, |=);
|
||||
CUTE_BINARY_OP(bit_xor_assign, ^=);
|
||||
CUTE_BINARY_OP(left_shift_assign, <<=);
|
||||
CUTE_BINARY_OP(right_shift_assign, >>=);
|
||||
|
||||
CUTE_BINARY_OP(logical_and, &&);
|
||||
CUTE_BINARY_OP(logical_or, ||);
|
||||
|
||||
CUTE_BINARY_OP(equal_to, ==);
|
||||
CUTE_BINARY_OP(not_equal_to, !=);
|
||||
CUTE_BINARY_OP(greater, >);
|
||||
CUTE_BINARY_OP(less, <);
|
||||
CUTE_BINARY_OP(greater_equal, >=);
|
||||
CUTE_BINARY_OP(less_equal, <=);
|
||||
|
||||
CUTE_NAMED_BINARY_OP(max_fn, cute::max);
|
||||
CUTE_NAMED_BINARY_OP(min_fn, cute::min);
|
||||
|
||||
#undef CUTE_BINARY_OP
|
||||
#undef CUTE_NAMED_BINARY_OP
|
||||
|
||||
/**********/
|
||||
/** Meta **/
|
||||
/**********/
|
||||
|
||||
template <class Fn, class Arg>
|
||||
struct bound_fn {
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(T&& arg) {
|
||||
return fn_(arg_, std::forward<T>(arg));
|
||||
}
|
||||
|
||||
Fn fn_;
|
||||
Arg arg_;
|
||||
};
|
||||
|
||||
template <class Fn, class Arg>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
bind(Fn const& fn, Arg const& arg) {
|
||||
return bound_fn<Fn,Arg>{fn, arg};
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
718  include/cute/algorithm/gemm.hpp  (new file)
@ -0,0 +1,718 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (remainder of the standard BSD-3-Clause header as above)
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/tensor.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/util/type_traits.hpp>

/** The gemm algorithm takes four (or three) tensors and computes
 *   D += A * B + C
 * It dispatches based on the number of modes each tensor has:
 *
 * 1. `(V) x (V) => (V)`.
 *      The element-wise product of vectors. Dispatches to FMA or MMA.
 * 2. `(M) x (N) => (M,N)`.
 *      The outer product of vectors. Dispatches to [3] with new mode K=(1).
 * 3. `(M,K) x (N,K) => (M,N)`.
 *      The product of matrices. Dispatches to [5] with MMA vector-mode V.
 * 4. `(V,M) x (V,N) => (V,M,N)`.
 *      The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n).
 * 5. `(V,M,K) x (V,N,K) => (V,M,N)`.
 *      The batched product of matrices. Dispatches to [4] for each (k).
 */

namespace cute
{
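// Added note (not part of the original header): the tutorial's inner call
// gemm(tCsA, tCsB, tCrC) is the three-argument form, i.e. C += A * B with D aliased to C.
// Per thread it has the `(M,K) x (N,K) => (M,N)` signature described in case [3], with shapes
// (8,8) x (8,8) => (8,8), and with no MMA atom specified it is ultimately lowered to
// element-wise UniversalFMA operations on the (8,8) register accumulator.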
|
||||
//
|
||||
// Three arguments to four
|
||||
//
|
||||
|
||||
template <class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> & C)
|
||||
{
|
||||
return gemm(C, A, B, C);
|
||||
}
|
||||
|
||||
template <class MMA,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> & C)
|
||||
{
|
||||
return gemm(mma, C, A, B, C);
|
||||
}
|
||||
|
||||
//
|
||||
// Accept mutable temporaries
|
||||
//
|
||||
|
||||
template <class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> && C)
|
||||
{
|
||||
return gemm(C, A, B, C);
|
||||
}
|
||||
|
||||
template <class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TD, DLayout> && D,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> const& C)
|
||||
{
|
||||
return gemm(D, A, B, C);
|
||||
}
|
||||
|
||||
template <class MMA,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> && C)
|
||||
{
|
||||
return gemm(mma, C, A, B, C);
|
||||
}
|
||||
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> && D,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> const& C)
|
||||
{
|
||||
return gemm(mma, D, A, B, C);
|
||||
}
|
||||
|
||||
//
|
||||
// Default MMA is UniversalFMA
|
||||
//
|
||||
|
||||
template <class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(Tensor<TD, DLayout> & D,
|
||||
Tensor<TA, ALayout> const& A,
|
||||
Tensor<TB, BLayout> const& B,
|
||||
Tensor<TC, CLayout> const& C)
|
||||
{
|
||||
using MMA = MMA_Atom<UniversalFMA<typename Tensor<TD,DLayout>::value_type,
|
||||
typename Tensor<TA,ALayout>::value_type,
|
||||
typename Tensor<TB,BLayout>::value_type,
|
||||
typename Tensor<TC,CLayout>::value_type>>;
|
||||
|
||||
return gemm(MMA{}, D, A, B, C);
|
||||
}
|
||||
|
||||
//
|
||||
// Thread-Local Register-Memory GEMMs
|
||||
//
|
||||
|
||||
// Dispatch [1]: (V) x (V) => (V)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 1 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 1 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 1 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 1 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (V) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (V) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (V) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (V) Logical data
|
||||
{
|
||||
// No static assertions on (V), MMA checks compatibility
|
||||
mma.call(D, A, B, C);
|
||||
}
|
||||
|
||||
// Dispatch [2]: (M) x (N) => (M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 1 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 1 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 2 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (M) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (N) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
|
||||
|
||||
gemm(mma,
|
||||
D, // (M,N)
|
||||
make_tensor(A.data(), append<2>(A.layout())), // (M,1)
|
||||
make_tensor(B.data(), append<2>(B.layout())), // (N,1)
|
||||
C); // (M,N)
|
||||
}
|
||||
|
||||
// Dispatch [3]: (M,K) x (N,K) => (M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 2 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 2 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 2 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (M,K) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (N,K) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
|
||||
|
||||
// Assert this is a 1-value MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
|
||||
|
||||
gemm(mma,
|
||||
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
|
||||
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
|
||||
make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K)
|
||||
make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N)
|
||||
}
|
||||
|
||||
// Dispatch [4]: (V,M) x (V,N) => (V,M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 2 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 2 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 3 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (V,M) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (V,N) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
|
||||
|
||||
// REGISTER .reuse OPTIMIZATIONS
|
||||
|
||||
auto M = size<1>(A);
|
||||
auto N = size<1>(B);
|
||||
|
||||
// 64-bit traversal specialization -- serpentine path
|
||||
if (size<0>(A) * sizeof(typename Tensor<TA,ALayout>::value_type) == 8 &&
|
||||
size<0>(B) * sizeof(typename Tensor<TB,BLayout>::value_type) == 8)
|
||||
{
|
||||
#if 1 // NOTE: Must depend on the C-matrix order... (which we can test)
|
||||
// Row-major iteration
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < M; ++m) {
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < N; ++n) {
|
||||
int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate
|
||||
gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns));
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Col-major iteration
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < N; ++n) {
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < M; ++m) {
|
||||
int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate
|
||||
gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else
|
||||
|
||||
// 32-bit traversal specialization -- kinked serpentine path
|
||||
if (size<0>(A) * sizeof(typename Tensor<TA,ALayout>::value_type) == 4 &&
|
||||
size<0>(B) * sizeof(typename Tensor<TB,BLayout>::value_type) == 4)
|
||||
{
|
||||
#if 1 // NOTE: Must depend on the C-matrix order... (which we can test)
|
||||
// Row-major iteration
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < M; m += 2) {
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < N; ++n) {
|
||||
int ns = (m & 2) ? N-1-n : n;
|
||||
gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns));
|
||||
|
||||
if (m+1 < M) {
|
||||
gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns));
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Col-major iteration
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < N; n += 2) {
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < M; ++m) {
|
||||
// Kinked serpentine traversal for maximum register reuse
|
||||
int ms = (n & 2) ? M-1-m : m;
|
||||
gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0));
|
||||
|
||||
if (n+1 < N) {
|
||||
gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
// Fallback to serpentine loop
|
||||
// Col-major iteration
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < N; ++n) {
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < M; ++m) {
|
||||
int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate
|
||||
gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
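// Illustration (not part of the dispatch logic): for M = 2, N = 4 the row-major
// serpentine path above visits
//   (m,ns) = (0,0) (0,1) (0,2) (0,3) (1,3) (1,2) (1,1) (1,0)
// so the B(_,ns) fragment loaded at the end of row m is reused immediately at the
// start of row m+1. The kinked 32-bit variant walks m in pairs for each ns, keeping
// two adjacent 32-bit A(_,m) fragments live so they can be paired in registers.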
|
||||
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 3 && is_rmem<TA>::value &&
|
||||
BLayout::rank == 3 && is_rmem<TB>::value &&
|
||||
CLayout::rank == 3 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (V,M,K) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (V,N,K) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
|
||||
|
||||
auto K = size<2>(A);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k = 0; k < K; ++k) {
|
||||
gemm(mma, D, A(_,_,k), B(_,_,k), C);
|
||||
}
|
||||
}
|
||||
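// Sketch of how these register-backed overloads compose (rD, rA, rB, rC are
// hypothetical rank-2 rmem tensors of shapes (M,N), (M,K), (N,K), (M,N)):
//   gemm(mma, rD, rA, rB, rC)                                   -> Dispatch [3]
// prepends a unit value mode and calls
//   gemm(mma, (1,M,N), (1,M,K), (1,N,K), (1,M,N))               -> Dispatch [5]
// which loops over K and, for each k, calls
//   gemm(mma, (1,M,N), (1,M), (1,N), (1,M,N))                   -> Dispatch [4]
// which walks (M,N) along a serpentine path and finally issues
//   gemm(mma, D(_,m,n), A(_,m), B(_,n), C(_,m,n))               -> Dispatch [1] -> mma.call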
|
||||
//
|
||||
// Thread-Local Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
// Dispatch [1]: (V) x (V) => (V)
|
||||
// Dispatch [2]: (M) x (N) => (M,N)
|
||||
// Dispatch [3]: (M,K) x (N,K) => (M,N)
|
||||
// Dispatch [4]: (V,M) x (V,N) => (V,M,N)
|
||||
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
|
||||
// Dispatch [3]: (M,K) x (N,K) => (M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (M,K) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (N,K) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));
|
||||
|
||||
// Assert this is a 1-value MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
|
||||
|
||||
gemm(mma,
|
||||
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
|
||||
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
|
||||
make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K)
|
||||
make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N)
|
||||
}
|
||||
|
||||
// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
|
||||
template <class MMA,
|
||||
class TD, class DLayout,
|
||||
class TA, class ALayout,
|
||||
class TB, class BLayout,
|
||||
class TC, class CLayout,
|
||||
__CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
|
||||
ALayout::rank == 3 && is_smem<TA>::value &&
|
||||
BLayout::rank == 3 && is_smem<TB>::value &&
|
||||
CLayout::rank == 3 && is_rmem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(MMA_Atom<MMA> const& mma,
|
||||
Tensor<TD, DLayout> & D, // (V,M,N) Logical data
|
||||
Tensor<TA, ALayout> const& A, // (V,M,K) Logical data
|
||||
Tensor<TB, BLayout> const& B, // (V,N,K) Logical data
|
||||
Tensor<TC, CLayout> const& C) // (V,M,N) Logical data
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK
|
||||
CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
|
||||
|
||||
auto rA = MMA_Atom<MMA>::make_fragment_A(A);
|
||||
auto rB = MMA_Atom<MMA>::make_fragment_B(B);
|
||||
|
||||
auto K = size<2>(A);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k = 0; k < K; ++k)
|
||||
{
|
||||
copy(A(_,_,k), rA(_,_,k));
|
||||
copy(B(_,_,k), rB(_,_,k));
|
||||
// Thread-level register gemm for k
|
||||
gemm(mma, D, rA(_,_,k), rB(_,_,k), C);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Collective Shared-Memory GEMMs
|
||||
//
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
class ALoadTransformOp, class BLoadTransformOp,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC,
|
||||
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
|
||||
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
|
||||
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
|
||||
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
|
||||
|
||||
using TypeA = typename TA::value_type;
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
static_assert(std::is_same_v<std::decay_t<std::invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
|
||||
"ALoadTransformOp functor must accept and return value of type TA::value_type");
|
||||
static_assert(std::is_same_v<std::decay_t<std::invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
|
||||
"BLoadTransformOp functor must accept and return value of type TB::value_type");
|
||||
|
||||
// Original, static size of the problem
|
||||
auto M = size<0>(sC);
|
||||
auto N = size<1>(sC);
|
||||
auto K = size<1>(sA);
|
||||
|
||||
// Block size of the compute tile
|
||||
auto BLK_M = tile_size<0>(thr_mma);
|
||||
auto BLK_N = tile_size<1>(thr_mma);
|
||||
auto BLK_K = tile_size<2>(thr_mma);
|
||||
|
||||
// Compute the "residues"
|
||||
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
|
||||
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
|
||||
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
|
||||
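  // For example (illustrative numbers): with M = 100 and BLK_M = 32,
  // m_residue = 100 - 32*(4-1) = 4, i.e. only 4 rows of the last M-tile are valid.
  // With K = 100 and BLK_K = 32, k_residue = 100 - 32*4 = -28, so after the origin
  // shift below only the last 4 columns of the zeroth K-tile are in bounds, which is
  // exactly what the k-predication in the prefetch checks.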
|
||||
// Shift the origin so k_residue is zeroth tile
|
||||
sA.data() = &sA(0,k_residue);
|
||||
sB.data() = &sB(0,k_residue);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
|
||||
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
|
||||
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
// Round the layout extents up to BLK_X
|
||||
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
|
||||
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(rounded_sA.layout()); print("\n");
|
||||
print(rounded_sB.layout()); print("\n");
|
||||
print(rounded_sC.layout()); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// Partition the sA and sB tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(tCsA.layout()); print("\n");
|
||||
print(tCsB.layout()); print("\n");
|
||||
print(tCsC.layout()); print("\n");
|
||||
print(tCrA.layout()); print("\n");
|
||||
print(tCrB.layout()); print("\n");
|
||||
print(tCrC.layout()); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREDICATION
|
||||
//
|
||||
|
||||
// Allocate the preds for only the MMA-mode of tCsA and tCsB
|
||||
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
|
||||
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
|
||||
|
||||
// Create coordinate tensors on a single compute block for predication
|
||||
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
|
||||
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
|
||||
|
||||
// Populate the m and n predicates
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpA); ++i) {
|
||||
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
|
||||
}
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpB); ++i) {
|
||||
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
|
||||
threadIdx.x,
|
||||
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
|
||||
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH k_block = 0 (with k-predication)
|
||||
//
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
|
||||
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
|
||||
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
|
||||
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
//
|
||||
// MAINLOOP
|
||||
//
|
||||
|
||||
// Clear accumulators
|
||||
clear(tCrC);
|
||||
|
||||
constexpr int K_BLOCK_MAX = size<2>(tCrA);
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
|
||||
{
|
||||
// static-if load the next k_block. No k-predication required on these loads.
|
||||
if (k_block < K_BLOCK_MAX-1)
|
||||
{
|
||||
// Load the next k_block
|
||||
int k_next = k_block + 1;
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
|
||||
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
|
||||
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GEMM on k_block in registers
|
||||
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
|
||||
|
||||
const bool isBetaZero = (beta == Beta{});
|
||||
|
||||
// Custom axpby_if for now
|
||||
CUTE_UNROLL
|
||||
for (int m = 0; m < size<1>(tCsC); ++m)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int n = 0; n < size<2>(tCsC); ++n)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size<0>(tCsC); ++i)
|
||||
{
|
||||
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
|
||||
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
|
||||
{
|
||||
tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class Alpha, class TA, class ALayout, class TB, class BLayout,
|
||||
class Beta, class TC, class CLayout,
|
||||
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
|
||||
BLayout::rank == 2 && is_smem<TB>::value &&
|
||||
CLayout::rank == 2 && is_smem<TC>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
gemm(ThrMMA<Args...> const& thr_mma,
|
||||
Alpha const& alpha,
|
||||
Tensor<TA, ALayout> sA,
|
||||
Tensor<TB, BLayout> sB,
|
||||
Beta const& beta,
|
||||
Tensor<TC, CLayout> sC)
|
||||
{
|
||||
gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
46
include/cute/algorithm/prefer.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// Infinite types that inherit from each other
|
||||
template <std::size_t N>
|
||||
struct prefer : prefer<N-1> {};
|
||||
|
||||
template <>
|
||||
struct prefer<0> {};
|
||||
|
||||
// Can be used to preferentially overload implementations
|
||||
// Higher N in prefer<N> has higher priority.
|
||||
|
||||
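// Example (illustrative only; do_something and fast_path/slow_path are hypothetical):
#if 0
template <class T>
auto do_something(T const& t, prefer<1>) -> decltype(t.fast_path())
{
  return t.fast_path();                  // chosen via SFINAE when T has a fast_path()
}

template <class T>
auto do_something(T const& t, prefer<0>)
{
  return t.slow_path();                  // generic fallback
}

template <class T>
auto do_something(T const& t)
{
  return do_something(t, prefer<1>{});   // prefer<1> converts to prefer<0>, so the
}                                        // highest-priority viable overload wins
#endif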
} // end namespace cute
|
||||
102
include/cute/algorithm/tensor_algorithms.hpp
Normal file
@ -0,0 +1,102 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/** Common algorithms on (hierarchical) tensors */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// for_each
|
||||
//
|
||||
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(Tensor<Engine,Layout> const& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
|
||||
{
|
||||
return for_each(tensor, static_cast<UnaryOp&&>(op));
|
||||
}
|
||||
|
||||
//
|
||||
// transform
|
||||
//
|
||||
|
||||
// Similar to std::transform, but applies the op in place and does not return an output iterator
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
|
||||
{
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
tensor(i) = static_cast<UnaryOp&&>(op)(tensor(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class Engine, class Layout, class UnaryOp>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
|
||||
{
|
||||
return transform(tensor, std::forward<UnaryOp>(op));
|
||||
}
|
||||
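// Usage sketch (illustrative; `frag` stands for any register-backed Tensor):
#if 0
auto frag = make_tensor<float>(Int<8>{});               // (8) rmem tensor
transform(frag, [](float x) { return 2.0f * x; });      // scale every element in place
for_each(frag, [](float x) { printf("%f\n", x); });     // visit every element read-only
#endif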
|
||||
} // end namespace cute
|
||||
846
include/cute/algorithm/tuple_algorithms.hpp
Normal file
@ -0,0 +1,846 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
/** Common algorithms on (hierarchical) tuples */
|
||||
/** Style choice:
|
||||
* Forward params [using static_cast<T&&>(.)] for const/non-const/ref/non-ref args
|
||||
* but don't bother forwarding functions as ref-qualified member fns are extremely rare
|
||||
*/
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Apply (Unpack)
|
||||
// (t, f) => f(t_0,t_1,...,t_n)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, int... I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
apply(T&& t, F&& f, seq<I...>)
|
||||
{
|
||||
return f(get<I>(static_cast<T&&>(t))...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
apply(T&& t, F&& f)
|
||||
{
|
||||
return detail::apply(static_cast<T&&>(t), f, tuple_seq<T>{});
|
||||
}
|
||||
|
||||
//
|
||||
// Transform Apply
|
||||
// (t, f, g) => g(f(t_0),f(t_1),...)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tapply(T&& t, F&& f, G&& g, seq<I...>)
|
||||
{
|
||||
return g(f(get<I>(static_cast<T&&>(t)))...);
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq<I...>)
|
||||
{
|
||||
return g(f(get<I>(static_cast<T0&&>(t0)),
|
||||
get<I>(static_cast<T1&&>(t1)))...);
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq<I...>)
|
||||
{
|
||||
return g(f(get<I>(static_cast<T0&&>(t0)),
|
||||
get<I>(static_cast<T1&&>(t1)),
|
||||
get<I>(static_cast<T2&&>(t2)))...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class F, class G>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T&& t, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T&&>(t), f, g, tuple_seq<T>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F, class G>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T0&& t0, T1&& t1, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), f, g, tuple_seq<T0>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F, class G>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g)
|
||||
{
|
||||
return detail::tapply(static_cast<T0&&>(t0), static_cast<T1&&>(t1), static_cast<T2&&>(t2), f, g, tuple_seq<T0>{});
|
||||
}
|
||||
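// Example (illustrative): square each element with f, then sum the results with g.
#if 0
auto t      = cute::make_tuple(1, 2, 3);
auto sum_sq = cute::transform_apply(t,
                                    [](auto x)     { return x * x; },
                                    [](auto... xs) { return (xs + ... + 0); });  // 1 + 4 + 9 = 14
#endif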
|
||||
//
|
||||
// For Each
|
||||
// (t, f) => f(t_0),f(t_1),...,f(t_n)
|
||||
//
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
for_each(T&& t, F&& f)
|
||||
{
|
||||
detail::apply(t, [&](auto&&... a) { (f(static_cast<decltype(a)&&>(a)), ...); }, tuple_seq<T>{});
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
for_each_leaf(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
return detail::apply(static_cast<T&&>(t), [&](auto&&... a){ return (for_each_leaf(static_cast<decltype(a)&&>(a), f), ...); }, tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<T&&>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Transform
|
||||
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
|
||||
//
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform(T const& t, F&& f)
|
||||
{
|
||||
return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform(T0 const& t0, T1 const& t1, F&& f)
|
||||
{
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
|
||||
return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f)
|
||||
{
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T1>::value, "Mismatched tuple_size");
|
||||
static_assert(tuple_size<T0>::value == tuple_size<T2>::value, "Mismatched tuple_size");
|
||||
return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq<T0>{});
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
transform_leaf(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return transform(t, [&](auto const& a) { return transform_leaf(a, f); });
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// find and find_if
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f, seq<>)
|
||||
{
|
||||
return cute::integral_constant<int, tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f, seq<I,Is...>)
|
||||
{
|
||||
if constexpr (decltype(f(get<I>(t)))::value) {
|
||||
return cute::integral_constant<int, I>{};
|
||||
} else {
|
||||
return find_if(t, f, seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find_if(T const& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::find_if(t, f, tuple_seq<T>{});
|
||||
} else {
|
||||
return cute::integral_constant<int, decltype(f(t))::value ? 0 : 1>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
find(T const& t, X const& x)
|
||||
{
|
||||
return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
auto
|
||||
none_of(T const& t, F&& f)
|
||||
{
|
||||
return cute::integral_constant<bool, decltype(find_if(t, f))::value == std::tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
auto
|
||||
all_of(T const& t, F&& f)
|
||||
{
|
||||
auto not_f = [&](auto const& a) { return !f(a); };
|
||||
return cute::integral_constant<bool, decltype(find_if(t, not_f))::value == std::tuple_size<T>::value>{};
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
auto
|
||||
any_of(T const& t, F&& f)
|
||||
{
|
||||
return cute::integral_constant<bool, !decltype(none_of(t, f))::value>{};
|
||||
}
|
||||
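// Example (illustrative): these predicates are resolved at compile time, so the
// functor must return an integral_constant-like type rather than a plain bool.
#if 0
auto shape   = cute::make_tuple(cute::Int<2>{}, cute::Int<1>{}, cute::Int<8>{});
auto is_one  = [](auto x) { return cute::integral_constant<bool, decltype(x)::value == 1>{}; };
auto any_one = cute::any_of (shape, is_one);   // statically true
auto all_one = cute::all_of (shape, is_one);   // statically false
auto idx     = cute::find_if(shape, is_one);   // static index of the first match (1)
#endif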
|
||||
//
|
||||
// Filter
|
||||
// (t, f) => <f(t_0),f(t_1),...,f(t_n)>
|
||||
//
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
filter_tuple(T const& t, F&& f)
|
||||
{
|
||||
return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); });
|
||||
}
|
||||
|
||||
template <class T0, class T1, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
filter_tuple(T0 const& t0, T1 const& t1, F&& f)
|
||||
{
|
||||
return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); });
|
||||
}
|
||||
|
||||
//
|
||||
// Fold (Reduce, Accumulate)
|
||||
// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// This impl compiles much faster than cute::apply and variadic args
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold(T&& t, V&& v, F&& f, seq<>)
|
||||
{
|
||||
return static_cast<V&&>(v);
|
||||
}
|
||||
|
||||
template <class T, class V, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold(T&& t, V&& v, F&& f, seq<I,Is...>)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t)));
|
||||
} else {
|
||||
return fold(static_cast<T&&>(t),
|
||||
f(static_cast<V&&>(v), get<I>(static_cast<T&&>(t))),
|
||||
f,
|
||||
seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fold(T&& t, V&& v, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
static_cast<V&&>(v),
|
||||
f,
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return f(static_cast<V&&>(v), static_cast<T&&>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
fold_first(T&& t, F&& f)
|
||||
{
|
||||
if constexpr (is_tuple<std::remove_reference_t<T>>::value) {
|
||||
return detail::fold(static_cast<T&&>(t),
|
||||
get<0>(static_cast<T&&>(t)),
|
||||
f,
|
||||
make_range<1,std::tuple_size<std::remove_reference_t<T>>::value>{});
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
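// Example (illustrative): fold a tuple down to a single value.
#if 0
auto total = cute::fold(cute::make_tuple(1, 2, 3), 0,
                        [](auto acc, auto x) { return acc + x; });       // ((0+1)+2)+3 = 6
auto prod  = cute::fold_first(cute::make_tuple(4, 5, 6),
                              [](auto acc, auto x) { return acc * x; }); // (4*5)*6 = 120
#endif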
|
||||
//
|
||||
// front, back, take, unwrap
|
||||
//
|
||||
|
||||
// Get the first non-tuple element in a hierarchical tuple
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
front(T&& t)
|
||||
{
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
return front(get<0>(static_cast<T&&>(t)));
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Get the last non-tuple element in a hierarchical tuple
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
back(T&& t)
|
||||
{
|
||||
if constexpr (is_tuple<remove_cvref_t<T>>::value) {
|
||||
constexpr int N = tuple_size<remove_cvref_t<T>>::value;
|
||||
return back(get<N-1>(static_cast<T&&>(t)));
|
||||
} else {
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Takes the elements in the range [B,E)
|
||||
template <int B, int E, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
take(T const& t)
|
||||
{
|
||||
return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range<B,E>{});
|
||||
}
|
||||
|
||||
// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
unwrap(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (tuple_size<T>::value == 1) {
|
||||
return unwrap(get<0>(t));
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Flatten a hierarchical tuple to a tuple of depth one.
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flatten_to_tuple(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
} else {
|
||||
return cute::make_tuple(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flatten(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); });
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
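// Example (illustrative): flattening a hierarchical profile.
#if 0
auto shape = cute::make_tuple(cute::make_tuple(2, 3), 4);   // ((2,3),4)
auto flat  = cute::flatten(shape);                          // (2,3,4)
auto tup   = cute::flatten_to_tuple(5);                     // (5) -- non-tuples get wrapped
#endif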
|
||||
//
|
||||
// insert and remove and replace
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Shortcut around tuple_cat for common insert/remove/repeat cases
|
||||
template <class T, class X, int... I, int... J, int... K>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
construct(T const& t, X const& x, seq<I...>, seq<J...>, seq<K...>)
|
||||
{
|
||||
return cute::make_tuple(get<I>(t)..., (void(J),x)..., get<K>(t)...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Insert x into the Nth position of the tuple
|
||||
template <int N, class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
insert(T const& t, X const& x)
|
||||
{
|
||||
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N,tuple_size<T>::value>{});
|
||||
}
|
||||
|
||||
// Remove the Nth element of the tuple
|
||||
template <int N, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
remove(T const& t)
|
||||
{
|
||||
return detail::construct(t, 0, make_seq<N>{}, seq<>{}, make_range<N+1,tuple_size<T>::value>{});
|
||||
}
|
||||
|
||||
// Replace the Nth element of the tuple with x
|
||||
template <int N, class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
replace(T const& t, X const& x)
|
||||
{
|
||||
return detail::construct(t, x, make_seq<N>{}, seq<0>{}, make_range<N+1,tuple_size<T>::value>{});
|
||||
}
|
||||
|
||||
// Replace the first element of the tuple with x
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
replace_front(T const& t, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size<T>::value>{});
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Replace the last element of the tuple with x
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
replace_back(T const& t, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::construct(t, x, make_seq<tuple_size<T>::value-1>{}, seq<0>{}, seq<>{});
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Make a tuple of Xs of tuple_size N
|
||||
//
|
||||
|
||||
template <int N, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
repeat(X const& x)
|
||||
{
|
||||
return detail::construct(0, x, seq<>{}, make_seq<N>{}, seq<>{});
|
||||
}
|
||||
|
||||
//
|
||||
// Make a tuple of Xs the same profile as tuple
|
||||
//
|
||||
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
repeat_like(T const& t, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return transform(t, [&](auto const& a) { return repeat_like(a,x); });
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Group the elements [B,E) of a T into a single element
|
||||
// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{})
|
||||
// => T<_1,_2,T<_3,_4>,_5,_6>{}
|
||||
template <int B, int E, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
group(T const& t)
|
||||
{
|
||||
return detail::construct(t, take<B,E>(t), make_seq<B>{}, seq<0>{}, make_range<E,tuple_size<T>::value>{});
|
||||
}
|
||||
|
||||
//
|
||||
// Extend a T to rank N by appending/prepending an element
|
||||
//
|
||||
|
||||
template <int N, class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
append(T const& a, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (N == tuple_size<T>::value) {
|
||||
return a;
|
||||
} else {
|
||||
static_assert(N > tuple_size<T>::value);
|
||||
return detail::construct(a, x, make_seq<tuple_size<T>::value>{}, make_seq<N-tuple_size<T>::value>{}, seq<>{});
|
||||
}
|
||||
} else {
|
||||
if constexpr (N == 1) {
|
||||
return a;
|
||||
} else {
|
||||
return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq<N-1>{}, seq<>{});
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
append(T const& a, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::construct(a, x, make_seq<tuple_size<T>::value>{}, seq<0>{}, seq<>{});
|
||||
} else {
|
||||
return cute::make_tuple(a, x);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <int N, class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
prepend(T const& a, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (N == tuple_size<T>::value) {
|
||||
return a;
|
||||
} else {
|
||||
static_assert(N > tuple_size<T>::value);
|
||||
return detail::construct(a, x, seq<>{}, make_seq<N-tuple_size<T>::value>{}, make_seq<tuple_size<T>::value>{});
|
||||
}
|
||||
} else {
|
||||
if constexpr (N == 1) {
|
||||
return a;
|
||||
} else {
|
||||
static_assert(N > 1);
|
||||
return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq<N-1>{}, seq<0>{});
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
template <class T, class X>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
prepend(T const& a, X const& x)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq<tuple_size<T>::value>{});
|
||||
} else {
|
||||
return cute::make_tuple(x, a);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
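// Example (illustrative): pad a profile up to a target rank.
#if 0
auto a = cute::append   (cute::make_tuple(2, 3), 1);    // (2,3,1)
auto b = cute::append<4>(cute::make_tuple(2, 3), 1);    // (2,3,1,1)
auto c = cute::prepend<3>(7, 1);                        // (1,1,7)
#endif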
|
||||
//
|
||||
// Inclusive scan (prefix sum)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class V, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
iscan(T const& t, V const& v, F&& f, seq<I,Is...>)
|
||||
{
|
||||
// Apply the function to v and the element at I
|
||||
auto v_next = f(v, get<I>(t));
|
||||
// Replace I with v_next
|
||||
auto t_next = replace<I>(t, v_next);
|
||||
|
||||
#if 0
|
||||
std::cout << "ISCAN i" << I << std::endl;
|
||||
std::cout << " t " << t << std::endl;
|
||||
std::cout << " i " << v << std::endl;
|
||||
std::cout << " f(i,t) " << v_next << std::endl;
|
||||
std::cout << " t_n " << t_next << std::endl;
|
||||
#endif
|
||||
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return t_next;
|
||||
} else {
|
||||
return iscan(t_next, v_next, f, seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
iscan(T const& t, V const& v, F&& f)
|
||||
{
|
||||
return detail::iscan(t, v, f, tuple_seq<T>{});
|
||||
}
|
||||
|
||||
//
|
||||
// Exclusive scan (prefix sum)
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class V, class F, int I, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
escan(T const& t, V const& v, F&& f, seq<I,Is...>)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
// Replace I with v
|
||||
return replace<I>(t, v);
|
||||
} else {
|
||||
// Apply the function to v and the element at I
|
||||
auto v_next = f(v, get<I>(t));
|
||||
// Replace I with v
|
||||
auto t_next = replace<I>(t, v);
|
||||
|
||||
#if 0
|
||||
std::cout << "ESCAN i" << I << std::endl;
|
||||
std::cout << " t " << t << std::endl;
|
||||
std::cout << " i " << v << std::endl;
|
||||
std::cout << " f(i,t) " << v_next << std::endl;
|
||||
std::cout << " t_n " << t_next << std::endl;
|
||||
#endif
|
||||
|
||||
// Recurse
|
||||
return escan(t_next, v_next, f, seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class V, class F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
escan(T const& t, V const& v, F&& f)
|
||||
{
|
||||
return detail::escan(t, v, f, tuple_seq<T>{});
|
||||
}
|
||||
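// Example (illustrative): running products over a tuple.
#if 0
auto incl = cute::iscan(cute::make_tuple(2, 3, 4), 1,
                        [](auto a, auto b) { return a * b; });   // (2, 6, 24)
auto excl = cute::escan(cute::make_tuple(2, 3, 4), 1,
                        [](auto a, auto b) { return a * b; });   // (1, 2, 6)
#endif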
|
||||
//
|
||||
// Zip (Transpose)
|
||||
//
|
||||
|
||||
// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input
|
||||
// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int J, class T, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip_(T const& t, seq<Is...>)
|
||||
{
|
||||
return cute::make_tuple(get<J>(get<Is>(t))...);
|
||||
}
|
||||
|
||||
template <class T, int... Is, int... Js>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip(T const& t, seq<Is...>, seq<Js...>)
|
||||
{
|
||||
static_assert(conjunction<bool_constant<tuple_size<tuple_element_t<0,T>>::value == tuple_size<tuple_element_t<Is,T>>::value>...>::value, "Mismatched Ranks");
|
||||
return cute::make_tuple(detail::zip_<Js>(t, seq<Is...>{})...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip(T const& t)
|
||||
{
|
||||
if constexpr (is_tuple<T>::value) {
|
||||
if constexpr (is_tuple<tuple_element_t<0,T>>::value) {
|
||||
return detail::zip(t, tuple_seq<T>{}, tuple_seq<tuple_element_t<0,T>>{});
|
||||
} else {
|
||||
return cute::make_tuple(t);
|
||||
}
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Convenience overload: pass the tuples in separately
|
||||
template <class T0, class T1, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip(T0 const& t0, T1 const& t1, Ts const&... ts)
|
||||
{
|
||||
return zip(cute::make_tuple(t0, t1, ts...));
|
||||
}
|
||||
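// Example (illustrative): zip transposes the two levels of nesting.
#if 0
auto z = cute::zip(cute::make_tuple(1, 2, 3),
                   cute::make_tuple(4, 5, 6));    // ((1,4),(2,5),(3,6))
#endif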
|
||||
//
|
||||
// zip2_by -- A guided zip for rank-2 tuples
|
||||
// Take a tuple like ((A,a),((B,b),(C,c)),d)
|
||||
// and produce a tuple ((A,(B,C)),(a,(b,c),d))
|
||||
// where the rank-2 modes are selected by the terminals of the guide (X,(X,X))
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class TG, int... Is, int... Js>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip2_by(T const& t, TG const& guide, seq<Is...>, seq<Js...>)
|
||||
{
|
||||
// zip2_by produces the modes like ((A,a),(B,b),...)
|
||||
auto split = cute::make_tuple(zip2_by(get<Is>(t), get<Is>(guide))...);
|
||||
|
||||
// Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y))
|
||||
return cute::make_tuple(cute::make_tuple(get<Is,0>(split)...),
|
||||
cute::make_tuple(get<Is,1>(split)..., get<Js>(t)...));
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, class TG>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zip2_by(T const& t, TG const& guide)
|
||||
{
|
||||
if constexpr (is_tuple<TG>::value) {
|
||||
constexpr int TR = tuple_size<T>::value;
|
||||
constexpr int GR = tuple_size<TG>::value;
|
||||
static_assert(TR >= GR, "Mismatched ranks");
|
||||
return detail::zip2_by(t, guide,
|
||||
make_range< 0, GR>{},
|
||||
make_range<GR, TR>{});
|
||||
} else {
|
||||
static_assert(tuple_size<T>::value == 2, "Mismatched ranks");
|
||||
return t;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
190
include/cute/arch/cluster_sm90.hpp
Normal file
@ -0,0 +1,190 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
|
||||
((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))))
|
||||
# define CUTE_ARCH_CLUSTER_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute {
|
||||
|
||||
CUTE_DEVICE void cluster_arrive_relaxed()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : );
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_DEVICE void cluster_arrive()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
asm volatile("barrier.cluster.arrive.aligned;\n" : : );
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_DEVICE void cluster_wait()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
asm volatile("barrier.cluster.wait.aligned;\n" : : );
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_DEVICE void cluster_sync()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
cluster_arrive();
|
||||
cluster_wait();
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns the dim3 grid size in terms of number of clusters.
|
||||
CUTE_DEVICE dim3 cluster_grid_dims()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t x, y, z;
|
||||
asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
|
||||
asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
|
||||
asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
|
||||
return {x, y, z};
|
||||
#else
|
||||
return gridDim;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns the dim3 cluster rank in the grid.
|
||||
CUTE_DEVICE dim3 cluster_id_in_grid()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t x, y, z;
|
||||
asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
|
||||
asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
|
||||
asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
|
||||
return {x, y, z};
|
||||
#else
|
||||
return blockIdx;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns the relative dim3 block rank local to the cluster.
|
||||
CUTE_DEVICE dim3 block_id_in_cluster()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t x, y, z;
|
||||
asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
|
||||
asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
|
||||
asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
|
||||
return {x, y, z};
|
||||
#else
|
||||
return {0,0,0};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns the dim3 cluster shape.
|
||||
CUTE_DEVICE dim3 cluster_shape()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t x, y, z;
|
||||
asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
|
||||
asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
|
||||
asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
|
||||
return {x, y, z};
|
||||
#else
|
||||
return {1,1,1};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get 1D ctaid in a cluster.
|
||||
CUTLASS_DEVICE uint32_t block_rank_in_cluster()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t rank;
|
||||
asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
|
||||
return rank;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Set the destination block-ID in cluster for a given SMEM Address
|
||||
CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank)
|
||||
{
|
||||
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
|
||||
uint32_t result;
|
||||
asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n"
|
||||
: "=r"(result)
|
||||
: "r"(smemAddr), "r"(rank));
|
||||
return result;
|
||||
#else
|
||||
return smemAddr;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false.
|
||||
CUTE_HOST_DEVICE uint32_t elect_one_sync()
|
||||
{
|
||||
#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED)
|
||||
uint32_t pred = 0;
|
||||
uint32_t laneid = 0;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b32 %rx;\n"
|
||||
".reg .pred %px;\n"
|
||||
" elect.sync %rx|%px, %2;\n"
|
||||
"@%px mov.s32 %1, 1;\n"
|
||||
" mov.s32 %0, %rx;\n"
|
||||
"}\n"
|
||||
: "+r"(laneid), "+r"(pred)
|
||||
: "r"(0xFFFFFFFF));
|
||||
return pred;
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
return (threadIdx.x % 32) == 0;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
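// Usage sketch (illustrative only): a kernel launched with a nontrivial cluster
// shape can use these helpers to coordinate the CTAs within its cluster.
#if 0
__global__ void cluster_aware_kernel()
{
  dim3     cluster = cute::cluster_shape();             // e.g. {2,1,1}
  uint32_t rank    = cute::block_rank_in_cluster();     // this CTA's 1D rank in the cluster

  // ... produce data in shared memory ...

  cute::cluster_arrive();   // signal that this CTA has finished producing
  cute::cluster_wait();     // wait until every CTA in the cluster has arrived
}
#endif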
} // end namespace cute
|
||||
71
include/cute/arch/copy.hpp
Normal file
@ -0,0 +1,71 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/util.hpp>
|
||||
#include <cute/numeric/uint128.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Direct Copy for any type
|
||||
//
|
||||
|
||||
template <class S, class D = S>
|
||||
struct UniversalCopy
|
||||
{
|
||||
using SRegisters = S[1];
|
||||
using DRegisters = D[1];
|
||||
|
||||
CUTE_HOST_DEVICE static constexpr void
|
||||
copy(S const& src,
|
||||
D & dst)
|
||||
{
|
||||
dst = src;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Placeholder for the copy algorithm's default, auto-vectorizing behavior
|
||||
//
|
||||
|
||||
struct DefaultCopy
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint128_t[1];
|
||||
};
|
||||
|
||||
using AutoVectorizingCopy = DefaultCopy;
|
||||
|
||||
} // end namespace cute
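// Usage sketch of the shared calling convention (a minimal host-side example,
// compiled with nvcc; values are illustrative): SRegisters/DRegisters describe
// the operation's register footprint and copy() moves one value.

#include <cute/arch/copy.hpp>
#include <cstdio>

int main()
{
  float src = 3.0f;
  float dst = 0.0f;
  cute::UniversalCopy<float>::copy(src, dst);  // dst = src
  std::printf("%f\n", dst);                    // prints 3.000000
  return 0;
}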
|
||||
215
include/cute/arch/copy_sm75.hpp
Normal file
@ -0,0 +1,215 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
|
||||
# define CUTE_ARCH_LDSM_SM75_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
struct SM75_U32x1_LDSM_N
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint128_t const& smem_src,
|
||||
uint32_t& dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
|
||||
: "=r"(dst)
|
||||
: "r"(smem_int_ptr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM75_U32x2_LDSM_N
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint128_t const& smem_src,
|
||||
uint32_t& dst0, uint32_t& dst1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
||||
: "=r"(dst0), "=r"(dst1)
|
||||
: "r"(smem_int_ptr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM75_U32x4_LDSM_N
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint32_t[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint128_t const& smem_src,
|
||||
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
|
||||
: "r"(smem_int_ptr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM75_U16x2_LDSM_T
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint128_t const& smem_src,
|
||||
uint32_t& dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
|
||||
: "=r"(dst)
|
||||
: "r"(smem_int_ptr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM75_U16x4_LDSM_T
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint32_t[2];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint128_t const& smem_src,
|
||||
uint32_t& dst0, uint32_t& dst1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
||||
: "=r"(dst0), "=r"(dst1)
|
||||
: "r"(smem_int_ptr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM75_U16x8_LDSM_T
|
||||
{
|
||||
using SRegisters = uint128_t[1];
|
||||
using DRegisters = uint32_t[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint128_t const& smem_src,
|
||||
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||
asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
|
||||
: "r"(smem_int_ptr));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Legacy LDSM interfaces that aren't very useful
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_ldsm(uint128_t const* const smem_ptr,
|
||||
T* rmem_ptr)
|
||||
{
|
||||
uint32_t* reg_ptr = reinterpret_cast<uint32_t*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_ldsm_trans(uint128_t const* const smem_ptr,
|
||||
T* rmem_ptr)
|
||||
{
|
||||
uint32_t* reg_ptr = reinterpret_cast<uint32_t*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace cute
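// Usage sketch (assumptions: Turing-or-newer target so CUTE_ARCH_LDSM_SM75_ENABLED
// is defined, blockDim.x == 128, data pattern is illustrative). ldmatrix is a
// warp-collective: every lane supplies one 128-bit shared-memory address and
// receives its fragments in registers.

__global__ void ldsm_sketch(uint32_t* out)
{
  __shared__ cute::uint128_t smem[32];
  int lane = threadIdx.x % 32;

  // Fill shared memory with an identifiable pattern (128 x 32-bit words).
  reinterpret_cast<uint32_t*>(smem)[threadIdx.x] = threadIdx.x;
  __syncthreads();

  uint32_t r0, r1, r2, r3;
  cute::SM75_U32x4_LDSM_N::copy(smem[lane], r0, r1, r2, r3);  // x4: four 8x8 b16 matrices per warp

  out[threadIdx.x * 4 + 0] = r0;
  out[threadIdx.x * 4 + 1] = r1;
  out[threadIdx.x * 4 + 2] = r2;
  out[threadIdx.x * 4 + 3] = r3;
}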
|
||||
138
include/cute/arch/copy_sm80.hpp
Normal file
@ -0,0 +1,138 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/// Copy via cp.async with caching at all levels
|
||||
template <class TS, class TD = TS>
|
||||
struct SM80_CP_ASYNC_CACHEALWAYS
|
||||
{
|
||||
using SRegisters = TS[1];
|
||||
using DRegisters = TD[1];
|
||||
|
||||
static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
|
||||
static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(TS const& gmem_src,
|
||||
TD & smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
TS const* gmem_ptr = &gmem_src;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"l"(gmem_ptr),
|
||||
"n"(sizeof(TS)));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/// Copy via cp.async with caching at global level
|
||||
template <class TS, class TD = TS>
|
||||
struct SM80_CP_ASYNC_CACHEGLOBAL
|
||||
{
|
||||
using SRegisters = TS[1];
|
||||
using DRegisters = TD[1];
|
||||
|
||||
static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
|
||||
static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(TS const& gmem_src,
|
||||
TD & smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
TS const* gmem_ptr = &gmem_src;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"l"(gmem_ptr),
|
||||
"n"(sizeof(TS)));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Establishes an ordering w.r.t. previously issued cp.async instructions. Does not block.
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cp_async_fence()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Blocks until all but the most recent N cp.async.commit_group operations have completed.
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cp_async_wait()
|
||||
{
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
if constexpr (N == 0) {
|
||||
asm volatile("cp.async.wait_all;\n" ::);
|
||||
} else {
|
||||
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
cp_async_wait(Int<N>)
|
||||
{
|
||||
return cp_async_wait<N>();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
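// Usage sketch of the intended issue -> commit -> wait flow (assumptions: SM80+
// target, blockDim.x == 128, names are illustrative).

__global__ void cp_async_sketch(float4 const* gmem, float4* out)
{
  __shared__ float4 smem[128];
  int t = threadIdx.x;

  cute::SM80_CP_ASYNC_CACHEALWAYS<float4>::copy(gmem[t], smem[t]);  // 16B async global->shared copy
  cute::cp_async_fence();     // commit everything issued so far as one group
  cute::cp_async_wait<0>();   // block until that group has landed in shared memory
  __syncthreads();

  out[t] = smem[t];
}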
|
||||
225
include/cute/arch/copy_sm90.hpp
Normal file
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
# define CUTE_ARCH_STSM_SM90_ENABLED
|
||||
# define CUTE_ARCH_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
struct SM90_U32x1_STSM_N
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint32_t const& src,
|
||||
uint128_t & smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(src));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_U32x2_STSM_N
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint32_t const& src0, uint32_t const& src1,
|
||||
uint128_t& smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(src0), "r"(src1));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_U32x4_STSM_N
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3,
|
||||
uint128_t& smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(src0), "r"(src1), "r"(src2), "r"(src3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_U16x2_STSM_T
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint32_t const& src,
|
||||
uint128_t& smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(src));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_U16x4_STSM_T
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint32_t const& src0, uint32_t const& src1,
|
||||
uint128_t& smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(src0), "r"(src1));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_U16x8_STSM_T
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3,
|
||||
uint128_t& smem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(src0), "r"(src1), "r"(src2), "r"(src3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Legacy STSM interfaces that aren't very useful
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_stsm(T const* const rmem_ptr,
|
||||
uint128_t* const smem_ptr)
|
||||
{
|
||||
uint32_t const* reg_ptr = reinterpret_cast<uint32_t const*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy_stsm_trans(T const* const rmem_ptr,
|
||||
uint128_t* const smem_ptr)
|
||||
{
|
||||
uint32_t const* reg_ptr = reinterpret_cast<uint32_t const*>(rmem_ptr);
|
||||
|
||||
// if constexpr
|
||||
if (sizeof(T) == 4) {
|
||||
SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 8) {
|
||||
SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]);
|
||||
}
|
||||
else if (sizeof(T) == 16) {
|
||||
SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]);
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
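// Usage sketch of the store-direction (STSM) wrappers above (assumptions: SM90
// target, blockDim.x == 32, input layout is illustrative): each lane contributes
// register fragments and the warp cooperatively writes 8x8 b16 matrices to
// shared memory.

__global__ void stsm_sketch(uint32_t const* in)
{
  __shared__ cute::uint128_t smem[32];
  int lane = threadIdx.x % 32;

  uint32_t r0 = in[lane * 4 + 0], r1 = in[lane * 4 + 1],
           r2 = in[lane * 4 + 2], r3 = in[lane * 4 + 3];

  cute::SM90_U32x4_STSM_N::copy(r0, r1, r2, r3, smem[lane]);  // warp-collective x4 stmatrix
  __syncwarp();
  // smem now holds the 8x8 b16 matrices written by this warp.
}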
|
||||
194
include/cute/arch/copy_sm90_desc.hpp
Normal file
@ -0,0 +1,194 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
#include <cute/arch/copy_sm90.hpp>
|
||||
|
||||
#include <cute/container/alignment.hpp>
|
||||
#include <cute/container/bit_field.hpp>
|
||||
#include <cute/numeric/int.hpp> // to_Format<[u]intX>
|
||||
#include <cute/numeric/half.hpp> // to_Format<half_t>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Barriers are 64 bits of user-managed information used in broadly two types of synchronization patterns:
|
||||
/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels)
|
||||
/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction)
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Initialize barrier present in shared memory
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
initialize_barrier(uint64_t& smem_barrier,                 // 64-bit user-managed barrier in smem
|
||||
int thread_count = 1) // Thread count expected to arrive/wait on this barrier
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
|
||||
asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(thread_count));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Set the number of bytes transferred per transaction
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
set_barrier_transaction_bytes(uint64_t& smem_barrier,      // 64-bit user-managed barrier in smem
|
||||
uint32_t bytes)                                            // Number of bytes transferred per TMA transaction
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
|
||||
asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Barrier wait
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
wait_barrier(uint64_t& smem_barrier,                       // 64-bit user-managed barrier in smem
|
||||
int phase_bit)                                             // Current phase bit the barrier is waiting on to flip
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .pred P1;\n"
|
||||
"LAB_WAIT:\n"
|
||||
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
|
||||
"@P1 bra.uni DONE;\n"
|
||||
"bra.uni LAB_WAIT;\n"
|
||||
"DONE:\n"
|
||||
"}\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"r"(phase_bit));
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
// Barrier arrive
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
arrive_barrier(uint64_t& smem_barrier)                     // 64-bit user-managed barrier in smem
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b64 state; \n"
|
||||
"mbarrier.arrive.shared.b64 state, [%0];\n"
|
||||
"}\n"
|
||||
:: "r"(smem_int_ptr));
|
||||
#endif
|
||||
}
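// Usage sketch of the transaction-based pattern (assumptions: SM90 target, one
// thread programs the barrier and issues the TMA load, the TMA issue itself is
// elided, buffer size and names are illustrative). Without the elided TMA load
// the wait below would never complete.

__global__ void mbarrier_sketch()
{
  __shared__ alignas(128) char staging[4096];
  __shared__ uint64_t bar;

  if (threadIdx.x == 0) {
    cute::initialize_barrier(bar, /*thread_count=*/1);
  }
  __syncthreads();

  int phase = 0;
  if (threadIdx.x == 0) {
    // Announce how many bytes the pending transaction will deposit ...
    cute::set_barrier_transaction_bytes(bar, sizeof(staging));
    // ... then issue the TMA load against `bar` (e.g. SM90_TMA_LOAD_2D::copy).
  }

  // All threads spin on the current phase until the transaction completes.
  cute::wait_barrier(bar, phase);
  // staging[] is now valid; flip `phase` before the barrier is reused.
}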
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// TMA Descriptor and utilities
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace TMA {
|
||||
|
||||
enum class SmemSwizzleBits : uint8_t {
|
||||
DISABLE = 0,
|
||||
B32 = 1,
|
||||
B64 = 2,
|
||||
B128 = 3,
|
||||
};
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
|
||||
template <class T>
|
||||
inline CUtensorMapDataType to_CUtensorMapDataType() {
|
||||
if constexpr (std::is_same<T, int8_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (std::is_same<T, uint8_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (std::is_same<T, uint16_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else
|
||||
if constexpr (std::is_same<T, uint32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else
|
||||
if constexpr (std::is_same<T, uint64_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else
|
||||
if constexpr (std::is_same<T, int32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else
|
||||
if constexpr (std::is_same<T, int64_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else
|
||||
if constexpr (std::is_same<T, half_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else
|
||||
if constexpr (std::is_same<T, float>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else
|
||||
if constexpr (std::is_same<T, double>::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else
|
||||
if constexpr (std::is_same<T, bfloat16_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
|
||||
if constexpr (std::is_same<T, tfloat32_t>::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
|
||||
{ static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
|
||||
}
|
||||
|
||||
inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) {
|
||||
switch (t) {
|
||||
default: assert(false && "Unknown SmemSwizzleBits!");
|
||||
case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // (__CUDACC_VER_MAJOR__ >= 12)
|
||||
} // end namespace TMA
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
using TmaDescriptor = CUtensorMap;
|
||||
#else
|
||||
using TmaDescriptor = struct { char bytes[128]; };
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Initiates a TensorMap Prefetch
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param)
|
||||
asm volatile (
|
||||
"prefetch.tensormap [%0];"
|
||||
:
|
||||
: "l"(gmem_int_desc)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
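// Usage sketch of descriptor prefetch (assumptions: the descriptor lives in
// global or constant memory; kernel and names are illustrative): one thread per
// CTA typically prefetches the TmaDescriptor before the mainloop starts issuing
// TMA copies against it.

__global__ void prefetch_sketch(cute::TmaDescriptor const* desc)
{
  if (threadIdx.x == 0) {
    cute::prefetch_tma_descriptor(desc);
  }
  __syncthreads();
  // ... issue TMA loads/stores that reference *desc ...
}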
|
||||
552
include/cute/arch/copy_sm90_tma.hpp
Normal file
@ -0,0 +1,552 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
#include <cute/arch/copy_sm90.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM90_TMA_LOAD_1D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes"
|
||||
" [%0], [%1, {%3}], [%2];"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_2D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
|
||||
" [%0], [%1, {%3, %4}], [%2];"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes"
|
||||
" [%0], [%1, {%3, %4, %5}], [%2];"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes"
|
||||
" [%0], [%1, {%3, %4, %5, %6}], [%2];"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes"
|
||||
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
return SM90_TMA_LOAD_2D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
return SM90_TMA_LOAD_3D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
return SM90_TMA_LOAD_4D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
return SM90_TMA_LOAD_5D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3, crd4);
|
||||
}
|
||||
};
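// Usage sketch of the rank-generic wrapper (assumptions: desc, bar, and smem_buf
// are set up elsewhere, e.g. with the mbarrier helpers in copy_sm90_desc.hpp; in
// full kernels this is normally driven through CuTe copy atoms rather than
// called directly). Overload resolution on the coordinate count picks the
// rank-specific atom.

__device__ void tma_load_tile_2d(cute::TmaDescriptor const& desc,
                                 uint64_t& bar, void* smem_buf,
                                 int col, int row)
{
  // Two coordinates -> SM90_TMA_LOAD_2D.
  cute::SM90_TMA_LOAD::copy(&desc, bar, smem_buf, col, row);
}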
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to the shared memory of multiple CTAs in a cluster
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM90_TMA_LOAD_1D_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
|
||||
" [%0], [%1, {%4}], [%2], %3;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"h"(multicast_mask),
|
||||
"r"(crd0)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_2D_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
|
||||
" [%0], [%1, {%4, %5}], [%2], %3;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_3D_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
|
||||
" [%0], [%1, {%4, %5, %6}], [%2], %3;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_4D_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
|
||||
" [%0], [%1, {%4, %5, %6, %7}], [%2], %3;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_5D_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster"
|
||||
" [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_LOAD_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_STORE : Initiates a TMA copy from shared memory to global memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM90_TMA_STORE_1D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(crd0)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_2D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(crd0), "r"(crd1)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];"
|
||||
:
|
||||
: "l"(gmem_int_desc), "r"(smem_int_ptr),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM90_TMA_STORE
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* const desc_ptr,
|
||||
void const* const smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4);
|
||||
}
|
||||
};
|
||||
|
||||
// Indicates arrival of the warp issuing TMA_STORE operations (commits them as one bulk async group)
|
||||
CUTE_HOST_DEVICE static void
|
||||
tma_store_arrive() {
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
asm volatile("cp.async.bulk.commit_group;");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Waits until all but the most recent Count committed TMA_STORE bulk groups have completed
|
||||
template<int Count>
|
||||
CUTE_HOST_DEVICE static void
|
||||
tma_store_wait() {
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
asm volatile(
|
||||
"cp.async.bulk.wait_group.read %0;"
|
||||
:
|
||||
: "n"(Count)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute
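// Usage sketch of the store-side flow issue -> tma_store_arrive -> tma_store_wait
// (assumptions: SM90 target, one lane per warp issues the store; desc, smem_buf,
// and the coordinates are illustrative).

__device__ void tma_store_tile_2d(cute::TmaDescriptor const& desc,
                                  void const* smem_buf, int col, int row)
{
  if ((threadIdx.x % 32) == 0) {
    cute::SM90_TMA_STORE::copy(&desc, smem_buf, col, row);  // resolves to SM90_TMA_STORE_2D
    cute::tma_store_arrive();  // commit the bulk group
  }
  cute::tma_store_wait<0>();   // wait until all committed stores have drained
  // smem_buf may now be safely reused.
}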
|
||||
64
include/cute/arch/mma.hpp
Normal file
@ -0,0 +1,64 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once

#include <cute/config.hpp>

#include <cute/arch/util.hpp>

namespace cute
{

//
// Direct FMA for any type
//

template <class D, class A = D, class B = A, class C = D>
struct UniversalFMA
{
  using DRegisters = D[1];
  using ARegisters = A[1];
  using BRegisters = B[1];
  using CRegisters = C[1];

  CUTE_HOST_DEVICE static constexpr void
  fma(D      & d,
      A const& a,
      B const& b,
      C const& c)
  {
    // Forward to an ADL/cute free function for these types
    using cute::fma;
    fma(d, a, b, c);
  }
};

} // end namespace cute
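UniversalFMA is the fallback 1x1x1 MMA operation: one element per operand, forwarded to whatever `fma` free function ADL finds for the element types. A minimal sketch, assuming cute's scalar `fma` overload reduces to `d = a * b + c` (the exact overload set and which header provides it are assumptions, hence the catch-all include):

```cpp
#include <cute/tensor.hpp>  // catch-all cute header; assumed to pull in the scalar fma overloads

CUTE_HOST_DEVICE float universal_fma_demo()
{
  float d = 0.0f;
  // Assumed behaviour of the ADL-found fma for float: d = a * b + c.
  cute::UniversalFMA<float>::fma(d, 2.0f, 3.0f, 1.0f);  // d == 7.0f
  return d;
}
```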

87  include/cute/arch/mma_sm61.hpp  Normal file
@@ -0,0 +1,87 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/

#pragma once

#include <cute/config.hpp>
#include <cute/arch/mma.hpp>

// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))
#  define CUTE_ARCH_MMA_SM61_ENABLED
#endif

namespace cute
{

struct SM61_DP4A
{
  using DRegisters = int32_t[1];
  using ARegisters = uint32_t[1];
  using BRegisters = uint32_t[1];
  using CRegisters = int32_t[1];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c)
  {
#if defined(CUTE_ARCH_MMA_SM61_ENABLED)
    asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
                 : "=r"(d)
                 : "r"(a), "r"(b), "r"(c));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED");
#endif
  }
};
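SM61_DP4A wraps the `dp4a` instruction: each 32-bit operand packs four 8-bit lanes, and the result is their 4-way dot product accumulated into a 32-bit integer. A host-side reference of that arithmetic for the signed-by-signed (`s32.s32`) case, useful as a test oracle; the function name is illustrative:

```cpp
#include <cstdint>

// Reference semantics of dp4a.s32.s32: treat each byte of a and b as a signed int8 lane.
inline int32_t dp4a_ref(uint32_t a, uint32_t b, int32_t c)
{
  int32_t d = c;
  for (int i = 0; i < 4; ++i) {
    int8_t ai = static_cast<int8_t>((a >> (8 * i)) & 0xFF);
    int8_t bi = static_cast<int8_t>((b >> (8 * i)) & 0xFF);
    d += int32_t(ai) * int32_t(bi);
  }
  return d;
}
```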
|
||||
struct SM61_DP2A
|
||||
{
|
||||
using DRegisters = int32_t[1];
|
||||
using ARegisters = uint32_t[1];
|
||||
using BRegisters = uint32_t[1];
|
||||
using CRegisters = int32_t[1];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM61_ENABLED)
|
||||
asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"
|
||||
: "=r"(d)
|
||||
: "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cute

329  include/cute/arch/mma_sm70.hpp  Normal file
@@ -0,0 +1,329 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
|
||||
# define CUTE_ARCH_MMA_SM70_SUPPORTED
|
||||
# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
|
||||
# define CUTE_ARCH_MMA_SM70_ENABLED
|
||||
# endif
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
// SM70 MMA 884 F16F16F16
//

struct SM70_8x8x4_F16F16F16F16_TN
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[4];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16"
                 "{%0, %1, %2, %3},"
                 "{%4, %5},"
                 "{%6, %7},"
                 "{%8, %9, %10, %11};\n"
                 : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
                 : "r"(a0), "r"(a1),
                   "r"(b0), "r"(b1),
                   "r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
    CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};
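All of the m8n8k4 F16 variants take their fragments as packed 32-bit registers: every `uint32_t` in ARegisters/BRegisters/CRegisters holds two half-precision values, so `uint32_t[2]` is four halves per thread. A small sketch of the bit-packing only; the thread-to-element mapping of the quad-pair layout is defined by the PTX ISA and not shown, and `pack_half2` is an illustrative helper:

```cpp
#include <cuda_fp16.h>
#include <cstdint>
#include <cstring>

// Pack two half values into one 32-bit MMA fragment register (lo in bits 0..15).
__host__ __device__ inline uint32_t pack_half2(__half lo, __half hi)
{
  __half pair[2] = { lo, hi };
  uint32_t r;
  memcpy(&r, pair, sizeof(r));  // bit-copy: two 16-bit halves fill the 32-bit register
  return r;
}
```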
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F16F16F16F16_NT
|
||||
{
|
||||
using DRegisters = uint32_t[4];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[4];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5},"
|
||||
"{%6, %7},"
|
||||
"{%8, %9, %10, %11};\n"
|
||||
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F16F16F16F16_NN
|
||||
{
|
||||
using DRegisters = uint32_t[4];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[4];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5},"
|
||||
"{%6, %7},"
|
||||
"{%8, %9, %10, %11};\n"
|
||||
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F16F16F16F16_TT
|
||||
{
|
||||
using DRegisters = uint32_t[4];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = uint32_t[4];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5},"
|
||||
"{%6, %7},"
|
||||
"{%8, %9, %10, %11};\n"
|
||||
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// SM70 MMA 884 F16F16F32
|
||||
//
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_TN
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};\n"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_NT
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_NN
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM70_8x8x4_F32F16F16F32_TT
|
||||
{
|
||||
using DRegisters = float[8];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[2];
|
||||
using CRegisters = float[8];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
float & d4, float & d5, float & d6, float & d7,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0, uint32_t const& b1,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3,
|
||||
float const& c4, float const& c5, float const& c6, float const& c7)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11},"
|
||||
"{%12, %13, %14, %15, %16, %17, %18, %19};"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
|
||||
"=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0), "r"(b1),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
|
||||
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute

120  include/cute/arch/mma_sm75.hpp  Normal file
@@ -0,0 +1,120 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
|
||||
# define CUTE_ARCH_MMA_SM75_SUPPORTED
|
||||
# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
|
||||
# define CUTE_ARCH_MMA_SM75_ENABLED
|
||||
# endif
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// SM75 MMA 1688 F16F16F32
|
||||
//
|
||||
|
||||
struct SM75_16x8x8_F32F16F16F32_TN
|
||||
{
|
||||
using DRegisters = float[4];
|
||||
using ARegisters = uint32_t[2];
|
||||
using BRegisters = uint32_t[1];
|
||||
using CRegisters = float[4];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(float & d0, float & d1, float & d2, float & d3,
|
||||
uint32_t const& a0, uint32_t const& a1,
|
||||
uint32_t const& b0,
|
||||
float const& c0, float const& c1, float const& c2, float const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM75_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5},"
|
||||
"{%6},"
|
||||
"{%7, %8, %9, %10};\n"
|
||||
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
|
||||
: "r"(a0), "r"(a1),
|
||||
"r"(b0),
|
||||
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// SM75 MMA 8816 S8S8S32
|
||||
//
|
||||
|
||||
struct SM75_8x8x16_S32S8S8S32_TN
|
||||
{
|
||||
using DRegisters = uint32_t[2];
|
||||
using ARegisters = uint32_t[1];
|
||||
using BRegisters = uint32_t[1];
|
||||
using CRegisters = uint32_t[2];
|
||||
|
||||
// Register asm fma
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(uint32_t & d0, uint32_t & d1,
|
||||
uint32_t const& a0,
|
||||
uint32_t const& b0,
|
||||
uint32_t const& c0, uint32_t const& c1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM75_ENABLED)
|
||||
asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32"
|
||||
"{%0, %1},"
|
||||
"{%2},"
|
||||
"{%3},"
|
||||
"{%4, %5};\n"
|
||||
: "=r"(d0), "=r"(d1)
|
||||
: "r"(a0),
|
||||
"r"(b0),
|
||||
"r"(c0), "r"(c1));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
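The integer MMA follows the same packed-register convention: each `uint32_t` of A and B carries four `int8` lanes, and C/D are two 32-bit accumulators per thread. A hedged device-side sketch of one accumulation step, assuming the operands have already been packed and distributed according to the m8n8k16 fragment layout; the names are illustrative:

```cpp
#include <cute/arch/mma_sm75.hpp>

// One in-place accumulate: D = A * B + C with D and C aliased to the same registers.
__device__ void mma_s8_step(uint32_t a_packed, uint32_t b_packed,
                            uint32_t& acc0, uint32_t& acc1)
{
  cute::SM75_8x8x16_S32S8S8S32_TN::fma(acc0, acc1, a_packed, b_packed, acc0, acc1);
}
```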
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // end namespace cute

2132  include/cute/arch/mma_sm80.hpp  Normal file
(File diff suppressed because it is too large.)

961  include/cute/arch/mma_sm90.hpp  Normal file
@@ -0,0 +1,961 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||
# define CUTE_ARCH_MMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x4 TN
|
||||
struct SM90_16x8x4_F64F64F64F64_TN
|
||||
{
|
||||
using DRegisters = double[4];
|
||||
using ARegisters = double[2];
|
||||
using BRegisters = double[1];
|
||||
using CRegisters = double[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(double & d0, double & d1, double & d2, double & d3,
|
||||
double const& a0, double const& a1,
|
||||
double const& b0,
|
||||
double const& c0, double const& c1, double const& c2, double const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5},"
|
||||
"{%6},"
|
||||
"{%7, %8, %9, %10};\n"
|
||||
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
|
||||
: "d"(a0), "d"(a1),
|
||||
"d"(b0),
|
||||
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x8 TN
|
||||
struct SM90_16x8x8_F64F64F64F64_TN
|
||||
{
|
||||
using DRegisters = double[4];
|
||||
using ARegisters = double[4];
|
||||
using BRegisters = double[2];
|
||||
using CRegisters = double[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(double & d0, double & d1, double & d2, double & d3,
|
||||
double const& a0, double const& a1, double const& a2, double const& a3,
|
||||
double const& b0, double const& b1,
|
||||
double const& c0, double const& c1, double const& c2, double const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5, %6, %7},"
|
||||
"{%8, %9},"
|
||||
"{%10, %11, %12, %13};\n"
|
||||
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
|
||||
: "d"(a0), "d"(a1), "d"(a2), "d"(a3),
|
||||
"d"(b0), "d"(b1),
|
||||
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x16 TN
|
||||
struct SM90_16x8x16_F64F64F64F64_TN
|
||||
{
|
||||
using DRegisters = double[4];
|
||||
using ARegisters = double[8];
|
||||
using BRegisters = double[4];
|
||||
using CRegisters = double[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(double & d0, double & d1, double & d2, double & d3,
|
||||
double const& a0, double const& a1, double const& a2, double const& a3,
|
||||
double const& a4, double const& a5, double const& a6, double const& a7,
|
||||
double const& b0, double const& b1, double const& b2, double const& b3,
|
||||
double const& c0, double const& c1, double const& c2, double const& c3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_MMA_SM90_ENABLED)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64"
|
||||
"{%0, %1, %2, %3},"
|
||||
"{%4, %5, %6, %7, %8, %9, %10, %11},"
|
||||
"{%12, %13, %14, %15},"
|
||||
"{%16, %17, %18, %19};\n"
|
||||
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
|
||||
: "d"(a0), "d"(a1), "d"(a2), "d"(a3),
|
||||
"d"(a4), "d"(a5), "d"(a6), "d"(a7),
|
||||
"d"(b0), "d"(b1), "d"(b2), "d"(b3),
|
||||
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x4 TN
|
||||
struct SM90_16x8x4_C64C64C64C64_TN
|
||||
{
|
||||
using DRegisters = complex<double>[4];
|
||||
using ARegisters = complex<double>[2];
|
||||
using BRegisters = complex<double>[1];
|
||||
using CRegisters = complex<double>[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(complex<double> & d0, complex<double> & d1,
|
||||
complex<double> & d2, complex<double> & d3,
|
||||
complex<double> const& a0, complex<double> const& a1,
|
||||
complex<double> const& b0,
|
||||
complex<double> const& c0, complex<double> const& c1,
|
||||
complex<double> const& c2, complex<double> const& c3)
|
||||
{
|
||||
// Because thrust::complex does not provide a mutable ref
|
||||
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
|
||||
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
|
||||
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
|
||||
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
|
||||
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
|
||||
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
|
||||
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
|
||||
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
|
||||
|
||||
// d.real() = a.real() * b.real() + c.real();
|
||||
SM90_16x8x4_F64F64F64F64_TN::fma(
|
||||
rd0, rd1, rd2, rd3,
|
||||
a0.real(), a1.real(),
|
||||
b0.real(),
|
||||
c0.real(), c1.real(), c2.real(), c3.real());
|
||||
|
||||
// d.imag() = a.imag() * b.real() + c.imag();
|
||||
SM90_16x8x4_F64F64F64F64_TN::fma(
|
||||
id0, id1, id2, id3,
|
||||
a0.imag(), a1.imag(),
|
||||
b0.real(),
|
||||
c0.imag(), c1.imag(), c2.imag(), c3.imag());
|
||||
|
||||
// d.real() = -a.imag() * b.imag() + d.real();
|
||||
SM90_16x8x4_F64F64F64F64_TN::fma(
|
||||
rd0, rd1, rd2, rd3,
|
||||
-a0.imag(), -a1.imag(),
|
||||
b0.imag(),
|
||||
d0.real(), d1.real(), d2.real(), d3.real());
|
||||
|
||||
// d.imag() = a.real() * b.imag() + d.imag();
|
||||
SM90_16x8x4_F64F64F64F64_TN::fma(
|
||||
id0, id1, id2, id3,
|
||||
a0.real(), a1.real(),
|
||||
b0.imag(),
|
||||
d0.imag(), d1.imag(), d2.imag(), d3.imag());
|
||||
}
|
||||
};
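There is no complex-valued MMA instruction; the struct above synthesizes one from four real F64 MMAs, evaluating d = a*b + c as real += a_r*b_r, imag += a_i*b_r, real += -a_i*b_i, imag += a_r*b_i. A scalar reference of the same four-pass schedule, purely illustrative:

```cpp
#include <complex>

// Scalar analogue of the four real FMAs used by SM90_16x8x4_C64C64C64C64_TN.
inline std::complex<double>
complex_fma_ref(std::complex<double> a, std::complex<double> b, std::complex<double> c)
{
  double dr =  a.real() * b.real() + c.real();  // pass 1: d.real  = ar*br + cr
  double di =  a.imag() * b.real() + c.imag();  // pass 2: d.imag  = ai*br + ci
  dr        = -a.imag() * b.imag() + dr;        // pass 3: d.real += -ai*bi
  di        =  a.real() * b.imag() + di;        // pass 4: d.imag +=  ar*bi
  return {dr, di};
}
```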
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x8 TN
|
||||
struct SM90_16x8x8_C64C64C64C64_TN
|
||||
{
|
||||
using DRegisters = complex<double>[4];
|
||||
using ARegisters = complex<double>[4];
|
||||
using BRegisters = complex<double>[2];
|
||||
using CRegisters = complex<double>[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(complex<double> & d0, complex<double> & d1,
|
||||
complex<double> & d2, complex<double> & d3,
|
||||
complex<double> const& a0, complex<double> const& a1,
|
||||
complex<double> const& a2, complex<double> const& a3,
|
||||
complex<double> const& b0, complex<double> const& b1,
|
||||
complex<double> const& c0, complex<double> const& c1,
|
||||
complex<double> const& c2, complex<double> const& c3)
|
||||
{
|
||||
// Because thrust::complex does not provide a mutable ref
|
||||
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
|
||||
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
|
||||
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
|
||||
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
|
||||
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
|
||||
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
|
||||
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
|
||||
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
|
||||
|
||||
// d.real() = a.real() * b.real() + c.real();
|
||||
SM90_16x8x8_F64F64F64F64_TN::fma(
|
||||
rd0, rd1, rd2, rd3,
|
||||
a0.real(), a1.real(), a2.real(), a3.real(),
|
||||
b0.real(), b1.real(),
|
||||
c0.real(), c1.real(), c2.real(), c3.real());
|
||||
|
||||
// d.imag() = a.imag() * b.real() + c.imag();
|
||||
SM90_16x8x8_F64F64F64F64_TN::fma(
|
||||
id0, id1, id2, id3,
|
||||
a0.imag(), a1.imag(), a2.imag(), a3.imag(),
|
||||
b0.real(), b1.real(),
|
||||
c0.imag(), c1.imag(), c2.imag(), c3.imag());
|
||||
|
||||
// d.real() = -a.imag() * b.imag() + d.real();
|
||||
SM90_16x8x8_F64F64F64F64_TN::fma(
|
||||
rd0, rd1, rd2, rd3,
|
||||
-a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
|
||||
b0.imag(), b1.imag(),
|
||||
d0.real(), d1.real(), d2.real(), d3.real());
|
||||
|
||||
// d.imag() = a.real() * b.imag() + d.imag();
|
||||
SM90_16x8x8_F64F64F64F64_TN::fma(
|
||||
id0, id1, id2, id3,
|
||||
a0.real(), a1.real(), a2.real(), a3.real(),
|
||||
b0.imag(), b1.imag(),
|
||||
d0.imag(), d1.imag(), d2.imag(), d3.imag());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// MMA 16x8x16 TN
|
||||
struct SM90_16x8x16_C64C64C64C64_TN
|
||||
{
|
||||
using DRegisters = complex<double>[4];
|
||||
using ARegisters = complex<double>[8];
|
||||
using BRegisters = complex<double>[4];
|
||||
using CRegisters = complex<double>[4];
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
fma(complex<double> & d0, complex<double> & d1,
|
||||
complex<double> & d2, complex<double> & d3,
|
||||
complex<double> const& a0, complex<double> const& a1,
|
||||
complex<double> const& a2, complex<double> const& a3,
|
||||
complex<double> const& a4, complex<double> const& a5,
|
||||
complex<double> const& a6, complex<double> const& a7,
|
||||
complex<double> const& b0, complex<double> const& b1,
|
||||
complex<double> const& b2, complex<double> const& b3,
|
||||
complex<double> const& c0, complex<double> const& c1,
|
||||
complex<double> const& c2, complex<double> const& c3)
|
||||
{
|
||||
// Because thrust::complex does not provide a mutable ref
|
||||
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
|
||||
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
|
||||
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
|
||||
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
|
||||
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
|
||||
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
|
||||
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
|
||||
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
|
||||
|
||||
// d.real() = a.real() * b.real() + c.real();
|
||||
SM90_16x8x16_F64F64F64F64_TN::fma(
|
||||
rd0, rd1, rd2, rd3,
|
||||
a0.real(), a1.real(), a2.real(), a3.real(),
|
||||
a4.real(), a5.real(), a6.real(), a7.real(),
|
||||
b0.real(), b1.real(), b2.real(), b3.real(),
|
||||
c0.real(), c1.real(), c2.real(), c3.real());
|
||||
|
||||
// d.imag() = a.imag() * b.real() + c.imag();
|
||||
SM90_16x8x16_F64F64F64F64_TN::fma(
|
||||
id0, id1, id2, id3,
|
||||
a0.imag(), a1.imag(), a2.imag(), a3.imag(),
|
||||
a4.imag(), a5.imag(), a6.imag(), a7.imag(),
|
||||
b0.real(), b1.real(), b2.real(), b3.real(),
|
||||
c0.imag(), c1.imag(), c2.imag(), c3.imag());
|
||||
|
||||
// d.real() = -a.imag() * b.imag() + d.real();
|
||||
SM90_16x8x16_F64F64F64F64_TN::fma(
|
||||
rd0, rd1, rd2, rd3,
|
||||
-a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
|
||||
-a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(),
|
||||
b0.imag(), b1.imag(), b2.imag(), b3.imag(),
|
||||
d0.real(), d1.real(), d2.real(), d3.real());
|
||||
|
||||
// d.imag() = a.real() * b.imag() + d.imag();
|
||||
SM90_16x8x16_F64F64F64F64_TN::fma(
|
||||
id0, id1, id2, id3,
|
||||
a0.real(), a1.real(), a2.real(), a3.real(),
|
||||
a4.real(), a5.real(), a6.real(), a7.real(),
|
||||
b0.imag(), b1.imag(), b2.imag(), b3.imag(),
|
||||
d0.imag(), d1.imag(), d2.imag(), d3.imag());
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <cute/arch/mma_sm90_desc.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
namespace GMMA {
|
||||
|
||||
template<
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class ElementC,
|
||||
class TileShape_MNK,
|
||||
GMMA::Major MajorA = GMMA::Major::K,
|
||||
GMMA::Major MajorB = GMMA::Major::K,
|
||||
auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
|
||||
// But most commonly leave empty for defaults
|
||||
>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
ss_op_selector()
|
||||
{
|
||||
static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
|
||||
static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
|
||||
static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
|
||||
auto Tile_N = size<1>(TileShape_MNK{});
|
||||
|
||||
// FP16 accumulator
|
||||
if constexpr (std::is_same_v<ElementC, half_t>) {
|
||||
static_assert(std::is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
|
||||
static_assert(std::is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
||||
|
||||
// Dispatch against the Tile N mode size
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// FP32 accumulator
|
||||
else if constexpr (std::is_same_v<ElementC, float>) {
|
||||
|
||||
// FP16 inputs
|
||||
if constexpr (std::is_same_v<ElementA, half_t>) {
|
||||
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// BF16 inputs
|
||||
else if constexpr (std::is_same_v<ElementA, bfloat16_t>) {
|
||||
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// TF32 inputs
|
||||
else if constexpr (std::is_same_v<ElementA, tfloat32_t>) {
|
||||
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
||||
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
||||
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x8_F32TF32TF32_SS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
// S32 accumulator
|
||||
else if constexpr (std::is_same_v<ElementC, int32_t>) {
|
||||
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
||||
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
// ElementA == int8_t && ElementB == int8_t
|
||||
if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, int8_t>) {
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32S8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// ElementA == int8_t && ElementB == uint8_t
|
||||
else if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, uint8_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32S8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// ElementA == uint8_t && ElementB == int8_t
|
||||
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, int8_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32U8S8_SS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// ElementA == uint8_t && ElementB == uint8_t
|
||||
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, uint8_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32U8U8_SS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unknown accumulator type
|
||||
else {
|
||||
static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
|
||||
}
|
||||
}
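ss_op_selector resolves, at compile time, the largest shared-memory-sourced GMMA atom whose N extent divides the requested tile's N mode and whose K-divisibility and layout constraints hold. It is what the CollectiveBuilder uses internally, but it can also be called directly when building a TiledMMA by hand. A hedged sketch; the tile shape and element types are illustrative, and `make_tiled_mma` is assumed to accept the returned operation as elsewhere in the 3.0 API:

```cpp
#include <cute/tensor.hpp>         // catch-all cute header (Shape, _128, half_t, ...)
#include <cute/arch/mma_sm90.hpp>  // ss_op_selector / rs_op_selector
#include <cute/atom/mma_atom.hpp>  // make_tiled_mma

using namespace cute;

// 128x128x64 tile, half inputs, float accumulators, both operands K-major:
// Tile_N = 128, so the selector yields SM90_64x128x16_F32F16F16_SS<...>{}.
using GmmaOp   = decltype(GMMA::ss_op_selector<half_t, half_t, float,
                                               Shape<_128, _128, _64>>());
using TiledMma = decltype(make_tiled_mma(GmmaOp{}));
```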
|
||||
|
||||
template<
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class ElementC,
|
||||
class TileShape_MNK,
|
||||
GMMA::Major MajorA = GMMA::Major::K,
|
||||
GMMA::Major MajorB = GMMA::Major::K,
|
||||
auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
|
||||
// But most commonly leave empty for defaults
|
||||
>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
rs_op_selector()
|
||||
{
|
||||
static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
|
||||
static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
|
||||
static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
|
||||
static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout.");
|
||||
auto Tile_N = size<1>(TileShape_MNK{});
|
||||
|
||||
// FP16 accumulator
|
||||
if constexpr (std::is_same_v<ElementC, half_t>) {
|
||||
static_assert(std::is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
|
||||
static_assert(std::is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
||||
|
||||
// Dispatch against the Tile N mode size
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// FP32 accumulator
|
||||
else if constexpr (std::is_same_v<ElementC, float>) {
|
||||
static_assert(std::is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
||||
|
||||
// FP16 inputs
|
||||
if constexpr (std::is_same_v<ElementA, half_t>) {
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// BF16 inputs
|
||||
else if constexpr (std::is_same_v<ElementA, bfloat16_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// TF32 inputs
|
||||
else if constexpr (std::is_same_v<ElementA, tfloat32_t>) {
|
||||
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x8_F32TF32TF32_RS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for the requested configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
// S32 accumulator
|
||||
else if constexpr (std::is_same_v<ElementC, int32_t>) {
|
||||
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
// ElementA == int8_t && ElementB == int8_t
|
||||
if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, int8_t>) {
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32S8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// ElementA == int8_t && ElementB == uint8_t
|
||||
else if constexpr (std::is_same_v<ElementA, int8_t> && std::is_same_v<ElementB, uint8_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32S8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// ElementA == uint8_t && ElementB == int8_t
|
||||
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, int8_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32U8S8_RS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
|
||||
// ElementA == uint8_t && ElementB == uint8_t
|
||||
else if constexpr (std::is_same_v<ElementA, uint8_t> && std::is_same_v<ElementB, uint8_t>) {
|
||||
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
||||
|
||||
if constexpr (Tile_N % 256 == 0) {
|
||||
return SM90_64x256x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 192 == 0) {
|
||||
return SM90_64x192x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 128 == 0) {
|
||||
return SM90_64x128x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 96 == 0) {
|
||||
return SM90_64x96x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 64 == 0) {
|
||||
return SM90_64x64x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 32 == 0) {
|
||||
return SM90_64x32x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 16 == 0) {
|
||||
return SM90_64x16x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else if constexpr (Tile_N % 8 == 0) {
|
||||
return SM90_64x8x32_S32U8U8_RS_TN<Args...>{};
|
||||
}
|
||||
else {
|
||||
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unknown accumulator type
|
||||
else {
|
||||
static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
|
||||
}
|
||||
}
|
||||
} // end namespace GMMA
|
||||
} // end namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
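The selector above simply walks the supported GMMA atom widths from 256 down to 8 and returns the widest operation whose N extent divides the requested tile. A minimal standalone sketch of that largest-divisor dispatch pattern, using placeholder operation types rather than the real SM90 atoms:

#include <cstdio>

// Placeholder standing in for the SM90_64xNx16_* operation structs above.
template <int N> struct FakeGmmaOp { static constexpr int n = N; };

// Return the widest placeholder op whose N divides Tile_N, mirroring the
// if-constexpr chain in the selector above.
template <int Tile_N>
constexpr auto select_op() {
  if constexpr      (Tile_N % 256 == 0) { return FakeGmmaOp<256>{}; }
  else if constexpr (Tile_N % 192 == 0) { return FakeGmmaOp<192>{}; }
  else if constexpr (Tile_N % 128 == 0) { return FakeGmmaOp<128>{}; }
  else if constexpr (Tile_N %  96 == 0) { return FakeGmmaOp< 96>{}; }
  else if constexpr (Tile_N %  64 == 0) { return FakeGmmaOp< 64>{}; }
  else if constexpr (Tile_N %  32 == 0) { return FakeGmmaOp< 32>{}; }
  else if constexpr (Tile_N %  16 == 0) { return FakeGmmaOp< 16>{}; }
  else if constexpr (Tile_N %   8 == 0) { return FakeGmmaOp<  8>{}; }
  else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); }
}

int main() {
  static_assert(decltype(select_op<96>())::n == 96);  // 96 divides 96 but not 128/192/256
  static_assert(decltype(select_op<80>())::n == 16);  // 80 falls through to the 16-wide atom
  std::printf("dispatch sketch OK\n");
  return 0;
}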
|
||||
131
include/cute/arch/mma_sm90_desc.hpp
Normal file
131
include/cute/arch/mma_sm90_desc.hpp
Normal file
@ -0,0 +1,131 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
// Config
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||
# define CUTE_ARCH_MMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// GMMA Descriptor and utilities
|
||||
|
||||
// GMMA enums and utilities
|
||||
namespace GMMA
|
||||
{
|
||||
|
||||
enum class LayoutType : uint8_t {
|
||||
INTERLEAVE = 0,
|
||||
B128 = 1,
|
||||
B64 = 2,
|
||||
B32 = 3,
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) {
|
||||
switch (t) {
|
||||
case LayoutType::INTERLEAVE: return "INTERLEAVE";
|
||||
case LayoutType::B128: return "B128";
|
||||
case LayoutType::B64: return "B64";
|
||||
case LayoutType::B32: return "B32";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Output operator for all enums in this namespace
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) {
|
||||
char const* s = to_string(t);
|
||||
if (s) {
|
||||
std::operator<<(os, s); // Explicit call to avoid ambiguity
|
||||
} else {
|
||||
os.setstate(std::ios_base::failbit);
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
} // end namespace GMMA
|
||||
|
||||
union GmmaDescriptor
|
||||
{
|
||||
uint64_t desc_;
|
||||
uint32_t reg32_[2];
|
||||
uint16_t reg16_[4];
|
||||
|
||||
// Bitfield implementation avoids the need for shifts in assignment
|
||||
struct {
|
||||
// start_address, bit [0,14), 4 LSB not included
|
||||
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
|
||||
// leading dimension byte offset, bit [16,30), 4 LSB not included
|
||||
// For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED
|
||||
// Unused for all SWIZZLE_* layouts (and assumed to be 1)
|
||||
// For T: This is the stride from the first 8 rows to the next 8 rows.
|
||||
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
|
||||
// stride dimension byte offset, bit [32,46), 4 LSB not included
|
||||
// For N: This is the stride from the first 8 rows to the next 8 rows.
|
||||
// For T: This is the stride from the first 8 cols to the next 8 cols.
|
||||
uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
|
||||
// base_offset, bit [49,52)
|
||||
// Valid only for SWIZZLE_128B and SWIZZLE_64B
|
||||
uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused
|
||||
// layout type, bit [62,64)
|
||||
// SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
|
||||
uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8)
|
||||
};
|
||||
|
||||
// Decay to a uint64_t
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator uint64_t() const noexcept { return desc_; }
|
||||
|
||||
// Printer
|
||||
CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t)
|
||||
{
|
||||
printf("GmmaDescriptor: 0x%016lx\n", t.desc_);
|
||||
printf(" start_addr : 0x%04x\n", t.start_address_);
|
||||
printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_);
|
||||
printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_);
|
||||
printf(" base_offset: 0x%01x\n", t.base_offset_);
|
||||
printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast<GMMA::LayoutType>(t.layout_type_)));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
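GmmaDescriptor packs the shared-memory operand description into one 64-bit value: addresses and strides are stored in 16-byte units (the "4 LSB not included" comments) and the swizzle mode goes in the top two bits. A small host-side sketch, assuming the header above can be included from ordinary host code and using made-up operand values:

#include <cstdint>
#include <cstdio>
#include <cute/arch/mma_sm90_desc.hpp>

int main() {
  cute::GmmaDescriptor desc{};
  desc.layout_type_         = uint8_t(cute::GMMA::LayoutType::B128); // 128-byte swizzle
  desc.start_address_       = 0x0400 >> 4;   // smem byte offset / 16
  desc.leading_byte_offset_ = 0;             // unused for swizzled layouts
  desc.stride_byte_offset_  = 1024 >> 4;     // 8-row stride in bytes / 16
  desc.base_offset_         = 0;
  std::printf("desc = 0x%016llx\n", static_cast<unsigned long long>(uint64_t(desc)));
  return 0;
}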
|
||||
12265
include/cute/arch/mma_sm90_gmma.hpp
Normal file
12265
include/cute/arch/mma_sm90_gmma.hpp
Normal file
File diff suppressed because it is too large
178
include/cute/arch/util.hpp
Normal file
178
include/cute/arch/util.hpp
Normal file
@ -0,0 +1,178 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
|
||||
#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
||||
extern "C" {
|
||||
// This NVVM intrinsic is subject to change in future versions of CUDA.
|
||||
// Clients should not call it directly.
|
||||
CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/// CUTE helper to cast SMEM pointer to unsigned
|
||||
CUTE_HOST_DEVICE
|
||||
uint32_t
|
||||
cast_smem_ptr_to_uint(void const* const ptr)
|
||||
{
|
||||
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
|
||||
// the previous internal intrinsics if they are available.
|
||||
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
|
||||
//
|
||||
// This NVVM intrinsic converts an address in shared memory to a plain
|
||||
// unsigned integer. This is necessary to pass to shared memory instructions
|
||||
// in inline PTX.
|
||||
//
|
||||
// In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
|
||||
//
|
||||
//__device__ size_t __cvta_generic_to_shared(void* ptr);
|
||||
|
||||
/// CUTE helper to get SMEM pointer
|
||||
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
|
||||
|
||||
#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
||||
|
||||
return __nvvm_get_smem_pointer(ptr);
|
||||
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
|
||||
uint32_t smem_ptr;
|
||||
|
||||
asm(
|
||||
"{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
|
||||
: "=r"(smem_ptr) : "l"(ptr));
|
||||
|
||||
return smem_ptr;
|
||||
|
||||
#else
|
||||
|
||||
|
||||
(void) ptr;
|
||||
printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n");
|
||||
return 0;
|
||||
|
||||
#endif
|
||||
}
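// Illustrative usage sketch (not part of this header): inline PTX takes shared-memory
// addresses as 32-bit integers, so device code converts the generic pointer first, e.g.
//
//   __shared__ uint32_t smem[32];
//   uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&smem[threadIdx.x]);
//   asm volatile("st.shared.u32 [%0], %1;\n" :: "r"(smem_addr), "r"(threadIdx.x));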
|
||||
|
||||
//
|
||||
// Utility for pointer interfaces
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Fn,
|
||||
class PtrS, int... Is,
|
||||
class PtrD, int... Id>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn,
|
||||
PtrS&& s, int_sequence<Is...>,
|
||||
PtrD&& d, int_sequence<Id...>)
|
||||
{
|
||||
return fn(s[Is]..., d[Id]...);
|
||||
}
|
||||
|
||||
template <class Fn,
|
||||
class PtrA, int... Ia,
|
||||
class PtrB, int... Ib,
|
||||
class PtrC, int... Ic>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn,
|
||||
PtrA&& a, int_sequence<Ia...>,
|
||||
PtrB&& b, int_sequence<Ib...>,
|
||||
PtrC&& c, int_sequence<Ic...>)
|
||||
{
|
||||
return fn(a[Ia]..., b[Ib]..., c[Ic]...);
|
||||
}
|
||||
|
||||
template <class Fn,
|
||||
class PtrD, int... Id,
|
||||
class PtrA, int... Ia,
|
||||
class PtrB, int... Ib,
|
||||
class PtrC, int... Ic>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn,
|
||||
PtrD&& d, int_sequence<Id...>,
|
||||
PtrA&& a, int_sequence<Ia...>,
|
||||
PtrB&& b, int_sequence<Ib...>,
|
||||
PtrC&& c, int_sequence<Ic...>)
|
||||
{
|
||||
return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <int SRegCount, int DRegCount,
|
||||
class Fn, class PtrS, class PtrD>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn, PtrS&& s, PtrD&& d)
|
||||
{
|
||||
return detail::explode(fn,
|
||||
s, make_int_sequence<SRegCount>{},
|
||||
d, make_int_sequence<DRegCount>{});
|
||||
}
|
||||
|
||||
template <int ARegCount, int BRegCount, int CRegCount,
|
||||
class Fn, class PtrA, class PtrB, class PtrC>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c)
|
||||
{
|
||||
return detail::explode(fn,
|
||||
a, make_int_sequence<ARegCount>{},
|
||||
b, make_int_sequence<BRegCount>{},
|
||||
c, make_int_sequence<CRegCount>{});
|
||||
}
|
||||
|
||||
template <int DRegCount, int ARegCount, int BRegCount, int CRegCount,
|
||||
class Fn, class PtrD, class PtrA, class PtrB, class PtrC>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c)
|
||||
{
|
||||
return detail::explode(fn,
|
||||
d, make_int_sequence<DRegCount>{},
|
||||
a, make_int_sequence<ARegCount>{},
|
||||
b, make_int_sequence<BRegCount>{},
|
||||
c, make_int_sequence<CRegCount>{});
|
||||
}
|
||||
|
||||
} // end namespace cute
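The explode() helpers above turn fixed-size register arrays into individual arguments of a single call, which is how the copy and MMA instruction wrappers receive their operands. A standalone host-side sketch with a toy functor (ToyOp and its register counts are made up for illustration; this assumes the header compiles on its own from host code):

#include <cstdint>
#include <cstdio>
#include <cute/arch/util.hpp>

// Toy "instruction": two source registers in, one destination register out.
struct ToyOp {
  static void call(uint32_t s0, uint32_t s1, uint32_t& d0) { d0 = s0 + s1; }
};

int main() {
  uint32_t src[2] = {3, 4};
  uint32_t dst[1] = {0};
  // Expands to ToyOp::call(src[0], src[1], dst[0]).
  cute::explode<2, 1>(ToyOp::call, src, dst);
  std::printf("%u\n", dst[0]);  // prints 7
  return 0;
}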
|
||||
671
include/cute/atom/copy_atom.hpp
Normal file
671
include/cute/atom/copy_atom.hpp
Normal file
@ -0,0 +1,671 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
// Generic copy_unpack for any Copy_Traits
|
||||
template <class Operation, class... Args,
|
||||
class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
copy_unpack(Copy_Traits<Operation, Args...> const&,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst)
|
||||
{
|
||||
// Specializations can generalize on these checks
|
||||
//static_assert(is_smem<TS>::value, "Expected smem for this Copy_Traits<Operation>");
|
||||
//static_assert(is_rmem<TD>::value, "Expected rmem for this Copy_Traits<Operation>");
|
||||
|
||||
using RegistersSrc = typename Operation::SRegisters;
|
||||
using RegistersDst = typename Operation::DRegisters;
|
||||
using RegTypeSrc = typename std::remove_extent<RegistersSrc>::type;
|
||||
using RegTypeDst = typename std::remove_extent<RegistersDst>::type;
|
||||
constexpr int RegNumSrc = std::extent<RegistersSrc>::value;
|
||||
constexpr int RegNumDst = std::extent<RegistersDst>::value;
|
||||
|
||||
Tensor rS = recast<RegTypeSrc>(src);
|
||||
Tensor rD = recast<RegTypeDst>(dst);
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
|
||||
"In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy.");
|
||||
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
|
||||
"In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy.");
|
||||
|
||||
detail::explode(Operation::copy,
|
||||
rS, make_int_sequence<RegNumSrc>{},
|
||||
rD, make_int_sequence<RegNumDst>{});
|
||||
}
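// Illustrative sketch (not part of this header): the generic copy_unpack above only
// requires the copy Operation to publish its register footprint and a static copy(), e.g.
//
//   struct MyCopyOp {
//     using SRegisters = uint32_t[1];    // one 32-bit source register
//     using DRegisters = uint32_t[1];    // one 32-bit destination register
//     CUTE_HOST_DEVICE static void copy(uint32_t const& s, uint32_t& d) { d = s; }
//   };
//
// copy_unpack then recasts the src/dst tensors to uint32_t fragments, checks that each
// vectorizes to exactly one register, and expands MyCopyOp::copy via detail::explode.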
|
||||
|
||||
|
||||
template <class... Args>
|
||||
struct Copy_Atom;
|
||||
|
||||
template <class CopyOperation, class T>
|
||||
struct Copy_Atom<CopyOperation, T> : Copy_Atom<Copy_Traits<CopyOperation>, T>
|
||||
{};
|
||||
|
||||
template <class... Args, class T>
|
||||
struct Copy_Atom<Copy_Traits<Args...>, T>
|
||||
: Copy_Traits<Args...>
|
||||
{
|
||||
using Traits = Copy_Traits<Args...>;
|
||||
|
||||
// Bit and Thr layouts from the Copy_Traits
|
||||
using ThrID = typename Traits::ThrID;
|
||||
using BitLayoutSrc = typename Traits::SrcLayout;
|
||||
using BitLayoutDst = typename Traits::DstLayout;
|
||||
using BitLayoutRef = typename Traits::RefLayout;
|
||||
|
||||
using ValType = T;
|
||||
|
||||
using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
|
||||
using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
|
||||
using ValLayoutRef = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutRef{}));
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType.");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType.");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType.");
|
||||
|
||||
static constexpr int NumValSrc = size<1>(ValLayoutSrc{});
|
||||
static constexpr int NumValDst = size<1>(ValLayoutDst{});
|
||||
|
||||
// Additional Trait parameters/transformations
|
||||
template <class... TraitsArgs>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
with(TraitsArgs&&... args) const {
|
||||
auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
|
||||
return Copy_Atom<decltype(traits), T>{traits};
|
||||
}
|
||||
|
||||
// Print thread and data layouts for debugging
|
||||
CUTE_HOST_DEVICE static
|
||||
void
|
||||
print_all()
|
||||
{
|
||||
print("ThrID: "); print(ThrID{}); print("\n");
|
||||
print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n");
|
||||
print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n");
|
||||
print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n");
|
||||
print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n");
|
||||
print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n");
|
||||
print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n");
|
||||
print("ValueType: %db", sizeof_bits<ValType>::value); print("\n");
|
||||
}
|
||||
|
||||
//
|
||||
// Tensor call interfaces
|
||||
//
|
||||
|
||||
// Cast, check, and call
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
call(Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) const
|
||||
{
|
||||
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
|
||||
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");
|
||||
|
||||
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value || is_constant<NumValDst, decltype(size(dst))>::value) {
|
||||
// Dispatch to unpack for instruction
|
||||
return copy_unpack(*this, src, dst);
|
||||
} else {
|
||||
// Recurse if needed by peeling the tensor mode
|
||||
return copy(*this, tensor<0>(src), tensor<0>(dst));
|
||||
}
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template <class SEngine, class SLayout,
|
||||
class DEngine, class DLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
call(Tensor<SEngine,SLayout> const& src,
|
||||
Tensor<DEngine,DLayout> && dst) const
|
||||
{
|
||||
return call(src, dst);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// A tiling of copy atoms
|
||||
//
|
||||
|
||||
template <class Copy_Atom,
|
||||
class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
|
||||
class ShapeTile_MN> // coord space
|
||||
struct TiledCopy : Copy_Atom
|
||||
{
|
||||
// Layout information from the CopyAtom
|
||||
using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx
|
||||
using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset
|
||||
using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset
|
||||
using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset
|
||||
|
||||
using AtomNumThr = decltype(size<0>(AtomLayoutRef{}));
|
||||
using AtomNumVal = decltype(size<1>(AtomLayoutRef{}));
|
||||
|
||||
// Layout information for the TiledCopy
|
||||
using Tiler_MN = ShapeTile_MN;
|
||||
using TiledShape_MN = decltype(shape(ShapeTile_MN{}));
|
||||
using TiledLayout_TV = LayoutCopy_TV;
|
||||
using TiledNumThr = decltype(size<0>(TiledLayout_TV{}));
|
||||
using TiledNumVal = decltype(size<1>(TiledLayout_TV{}));
|
||||
|
||||
CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom");
|
||||
CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom");
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
// (M,N,...)
|
||||
// to shape
|
||||
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
|
||||
// where
|
||||
// ThrV: The threads local to a COPY_ATOM Src.
|
||||
// ThrX: The threads tiled across COPY_ATOMs Src.
|
||||
// FrgV: The values local to a COPY_ATOM Src.
|
||||
// RestM: The values tiled in M.
|
||||
// RestN: The values tiled in N.
|
||||
template <class STensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
tidfrg_S(STensor&& stensor)
|
||||
{
|
||||
return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}));
|
||||
}
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
// (M,N,...)
|
||||
// to shape
|
||||
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
|
||||
// where
|
||||
// ThrV: The threads local to a COPY_ATOM Dst.
|
||||
// ThrX: The threads tiled across COPY_ATOMs Dst.
|
||||
// FrgV: The values local to a COPY_ATOM Dst.
|
||||
// RestM: The values tiled in M.
|
||||
// RestN: The values tiled in N.
|
||||
template <class DTensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
tidfrg_D(DTensor&& dtensor)
|
||||
{
|
||||
return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}));
|
||||
}
|
||||
|
||||
template <class Tensor, class Ref2TrgLayout>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg)
|
||||
{
|
||||
constexpr int R = remove_cvref_t<Tensor>::rank;
|
||||
static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
|
||||
// Generalize the dimension checks for arbitrary rank
|
||||
//CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
|
||||
//CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});
|
||||
|
||||
// Take the thrs/vals that the atom is interested in
|
||||
// NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
|
||||
auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{}));
|
||||
// ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n)
|
||||
|
||||
// Transform to the trg layout
|
||||
auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _);
|
||||
// ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n)
|
||||
|
||||
// Transform the thrs mode from thrid to thr_idx
|
||||
// NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
|
||||
auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{});
|
||||
// ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n)
|
||||
|
||||
/// ==================
|
||||
|
||||
// Tile the tensor for TiledLayout
|
||||
auto t_tensor = zipped_divide(tensor, Tiler_MN{});
|
||||
// ((TileM,TileN,...),(RestM,RestN,...))
|
||||
|
||||
// Transform the tile mode
|
||||
auto tv_tensor = t_tensor.compose(thrval2mn, _);
|
||||
// ((thrid,val),(RM,RN,...))
|
||||
|
||||
// Unfold and return
|
||||
return tv_tensor(make_coord(_,_), _);
|
||||
}
|
||||
|
||||
// retile_S and retile_D assume they are working with the reference layout -- they are the same
|
||||
template <class Tensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
retile(Tensor&& tensor)
|
||||
{
|
||||
constexpr int R = remove_cvref_t<Tensor>::rank;
|
||||
// Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation
|
||||
|
||||
// Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV.
|
||||
// Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV
|
||||
// and that shape is what we gather from the other modes of tensor
|
||||
|
||||
auto V = size<0>(tensor);
|
||||
|
||||
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{}));
|
||||
// (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV
|
||||
|
||||
auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{}));
|
||||
// (atom_vals,rest_vals) -> (v,m,n)
|
||||
|
||||
/// =======
|
||||
|
||||
// Tile the tensor for TileFrg
|
||||
auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V));
|
||||
// ((TileV,TileM,TileN,...),(1,RestM,RestN,...))
|
||||
|
||||
// Transform the tile mode
|
||||
auto v_tensor = t_tensor.compose(frg_layout_v, _);
|
||||
// ((atom_vals,rest_vals),(1,RM,RN,...))
|
||||
|
||||
// Unfold and return
|
||||
return v_tensor(_, append<R>(Int<0>{},_));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutS_MN()
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_S = make_layout(TiledShape_MN{});
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
auto layoutS_TV = tidfrg_S(ref_S);
|
||||
// (M,K) -> (thr_idx,val_idx)
|
||||
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
|
||||
|
||||
return cute::make_tuple(layoutS_MK, thrID_S);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutS_TV()
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_S = make_layout(TiledShape_MN{});
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
return tidfrg_S(ref_S)(_,_,Int<0>{});
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutD_MN()
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_D = make_layout(TiledShape_MN{});
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
auto layoutD_TV = tidfrg_D(ref_D);
|
||||
// (M,K) -> (thr_idx,val_idx)
|
||||
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D));
|
||||
|
||||
// athrid = (v,m,k) -> thr_idx
|
||||
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
|
||||
|
||||
return cute::make_tuple(layoutD_MK, thrID_D);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
get_layoutD_TV()
|
||||
{
|
||||
// (M,N) -> (M,N)
|
||||
auto ref_D = make_layout(TiledShape_MN{});
|
||||
// (thr_idx,val_idx) -> (M,N)
|
||||
return tidfrg_D(ref_D)(_,_,Int<0>{});
|
||||
}
|
||||
|
||||
template <class ThrIdx>
|
||||
struct ThrCopy : Copy_Atom
|
||||
{
|
||||
ThrIdx thr_idx_;
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {}
|
||||
|
||||
template <class STensor>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
partition_S(STensor&& stensor) {
|
||||
//static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename Copy_Atom::ValType),
|
||||
// "Expected ValType for tiling SrcTensor.");
|
||||
auto thr_tensor = make_tensor(std::forward<STensor>(stensor).data(), tidfrg_S(stensor.layout()));
|
||||
return thr_tensor(thr_idx_, _, repeat<rank_v<STensor>>(_));
|
||||
}
|
||||
|
||||
template <class DTensor>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
partition_D(DTensor&& dtensor) {
|
||||
//static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename Copy_Atom::ValType),
|
||||
// "Expected ValType for tiling DstTensor.");
|
||||
auto thr_tensor = make_tensor(std::forward<DTensor>(dtensor).data(), tidfrg_D(dtensor.layout()));
|
||||
return thr_tensor(thr_idx_, _, repeat<rank_v<DTensor>>(_));
|
||||
}
|
||||
|
||||
template <class STensor>
|
||||
CUTE_HOST_DEVICE static
|
||||
auto
|
||||
retile_S(STensor&& stensor) {
|
||||
static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename Copy_Atom::ValType),
|
||||
"Expected ValType for tiling SrcTensor.");
|
||||
return make_tensor(std::forward<STensor>(stensor).data(), TiledCopy::retile(stensor.layout()));
|
||||
}
|
||||
|
||||
template <class DTensor>
|
||||
CUTE_HOST_DEVICE static
|
||||
auto
|
||||
retile_D(DTensor&& dtensor) {
|
||||
static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename Copy_Atom::ValType),
|
||||
"Expected ValType for tiling DstTensor.");
|
||||
return make_tensor(std::forward<DTensor>(dtensor).data(), TiledCopy::retile(dtensor.layout()));
|
||||
}
|
||||
};
|
||||
|
||||
template <class ThrIdx,
|
||||
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
|
||||
CUTE_HOST_DEVICE static
|
||||
auto
|
||||
get_slice(ThrIdx const& thr_idx)
|
||||
{
|
||||
return ThrCopy<ThrIdx>(thr_idx);
|
||||
}
|
||||
|
||||
template <class ThrIdx,
|
||||
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
|
||||
CUTE_HOST_DEVICE static
|
||||
auto
|
||||
get_thread_slice(ThrIdx const& thr_idx)
|
||||
{
|
||||
return get_slice(thr_idx);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <class... Args,
|
||||
class LayoutCopy_TV,
|
||||
class... TLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_impl(Copy_Atom<Args...> const& atom,
|
||||
LayoutCopy_TV const&,
|
||||
Tile<TLayout...> const&)
|
||||
{
|
||||
return TiledCopy<Copy_Atom<Args...>, LayoutCopy_TV, Tile<TLayout...>>{atom};
|
||||
}
|
||||
|
||||
//
|
||||
// These tile the Copy_Atom as a whole
|
||||
//
|
||||
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_A(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
{
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{})));
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_B(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
{
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{})));
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma)
|
||||
{
|
||||
using MNK = typename TiledMMA::TiledShape_MNK;
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{})));
|
||||
}
|
||||
|
||||
template <class... Args,
|
||||
class ThrLayout,
|
||||
class ValLayout = Layout<_1>>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy(Copy_Atom<Args...> const& copy_atom,
|
||||
ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx
|
||||
ValLayout const& val_layout = {})
|
||||
{
|
||||
constexpr int R = cute::max(rank_v<ThrLayout>, rank_v<ValLayout>);
|
||||
|
||||
auto thr_layout_mn = append<R>(thr_layout, Layout<_1>{});
|
||||
auto val_layout_mn = append<R>(val_layout, Layout<_1>{});
|
||||
|
||||
// Take the raked_products to compute the Layout_MN
|
||||
auto layout_mn = raked_product(thr_layout_mn, val_layout_mn);
|
||||
auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
|
||||
|
||||
//print("thr_layout: "); print(thr_layout_mn); print("\n");
|
||||
//print("val_layout: "); print(val_layout_mn); print("\n");
|
||||
//print("layout_mn : "); print(layout_mn); print("\n");
|
||||
//print("layout_tv : "); print(layout_tv); print("\n");
|
||||
|
||||
return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn)));
|
||||
}
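// Illustrative usage sketch (not part of this header), assuming gX and sX are gmem/smem
// tensors tiled to a matching (M,N) shape: build a 128-thread TiledCopy in which each
// thread moves 8 contiguous half_t values, then let each thread copy its own fragment.
//
//   auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, half_t>{},
//                                     Layout<Shape<_16,_8>>{},   // (m,n) -> thr_idx
//                                     Layout<Shape< _1,_8>>{});  // (m,n) -> val_idx
//   auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
//   auto tXgX = thr_copy.partition_S(gX);   // this thread's source fragment
//   auto tXsX = thr_copy.partition_D(sX);   // this thread's destination fragment
//   copy(tiled_copy, tXgX, tXsX);           // cute::copy algorithm, defined elsewhere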
|
||||
|
||||
// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy
|
||||
template <class... Args,
|
||||
class TiledCopy>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_S(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledCopy const& tiled_copy)
|
||||
{
|
||||
return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{});
|
||||
}
|
||||
|
||||
// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy
|
||||
template <class... Args,
|
||||
class TiledCopy>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_D(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledCopy const& tiled_copy)
|
||||
{
|
||||
return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{});
|
||||
}
|
||||
|
||||
//
|
||||
// Size
|
||||
//
|
||||
|
||||
// The logical size of a TileCopy
|
||||
template <int... I, class... Args>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tile_size(TiledCopy<Args...> const&)
|
||||
{
|
||||
return size<I...>(typename TiledCopy<Args...>::TiledShape_MN{});
|
||||
}
|
||||
|
||||
// The number of threads involved in a TiledCopy
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
size(TiledCopy<Args...> const&)
|
||||
{
|
||||
return typename TiledCopy<Args...>::TiledNumThr{};
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
print_latex(TiledCopy<Args...> const& copy)
|
||||
{
|
||||
auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN();
|
||||
auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN();
|
||||
|
||||
print_latex_copy(layoutS_MN, thrID_S,
|
||||
layoutD_MN, thrID_D);
|
||||
}
|
||||
|
||||
// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread
|
||||
template <class LayoutS, class ThrIDS,
|
||||
class LayoutD, class ThrIDD>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
|
||||
|
||||
assert(size<0>(S) == size<0>(D));
|
||||
assert(size<1>(S) == size<1>(D));
|
||||
|
||||
char const* latex_header =
|
||||
"\\documentclass{standalone}\n"
|
||||
"\\usepackage{tikz}\n"
|
||||
"\\usetikzlibrary{external}\n"
|
||||
"\\tikzexternalize\n"
|
||||
"\\begin{document}\n"
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n";
|
||||
char const* latex_footer =
|
||||
"\\end{tikzpicture}\n"
|
||||
"\\end{document}\n";
|
||||
|
||||
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
|
||||
"{rgb,255:red,175;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,255;blue,175}",
|
||||
"{rgb,255:red,255;green,175;blue,175}",
|
||||
"{rgb,255:red,210;green,210;blue,255}",
|
||||
"{rgb,255:red,210;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,255;blue,210}",
|
||||
"{rgb,255:red,255;green,210;blue,210}",};
|
||||
|
||||
// Header
|
||||
printf("%% LayoutS: "); print(S); printf("\n");
|
||||
printf("%% ThrIDS : "); print(TS); printf("\n");
|
||||
printf("%% LayoutD: "); print(D); printf("\n");
|
||||
printf("%% ThrIDD : "); print(TD); printf("\n\n");
|
||||
|
||||
printf(latex_header);
|
||||
|
||||
// S starting at 0,0
|
||||
for (int i = 0; i < size<0>(S); ++i) {
|
||||
for (int j = 0; j < size<1>(S); ++j) {
|
||||
int thrid = S(i,j) % size(TS);
|
||||
int val_idx = S(i,j) / size(TS);
|
||||
int thr_idx = TS(thrid);
|
||||
|
||||
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color_map[thr_idx % 8],
|
||||
i, j,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// D starting at 0,size<1>(S)+3
|
||||
for (int i = 0; i < size<0>(D); ++i) {
|
||||
for (int j = 0; j < size<1>(D); ++j) {
|
||||
int thrid = D(i,j) % size(TD);
|
||||
int val_idx = D(i,j) / size(TD);
|
||||
int thr_idx = TD(thrid);
|
||||
|
||||
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
|
||||
color_map[thr_idx % 8],
|
||||
i, j + size<1>(S) + 3,
|
||||
thr_idx, val_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// S Labels
|
||||
for (int i = 0, j = -1; i < size<0>(S); ++i) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
|
||||
}
|
||||
for (int j = 0, i = -1; j < size<1>(S); ++j) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
|
||||
}
|
||||
// D Labels
|
||||
for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i);
|
||||
}
|
||||
for (int j = 0, i = -1; j < size<1>(D); ++j) {
|
||||
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j);
|
||||
}
|
||||
|
||||
// Footer
|
||||
printf(latex_footer);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
#include <cute/atom/copy_traits_sm75.hpp>
|
||||
#include <cute/atom/copy_traits_sm80.hpp>
|
||||
#include <cute/atom/copy_traits_sm90.hpp>
|
||||
// Config
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||
#include <cute/atom/copy_traits_sm90_tma.hpp>
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
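print_latex emits a standalone TikZ picture of the source and destination thread/value assignment of a TiledCopy, which is handy when designing new tilings. A host-side sketch (the layouts are just an example; this assumes the CuTe headers compile from host code):

#include <cute/tensor.hpp>
#include <cute/atom/copy_atom.hpp>

int main() {
  using namespace cute;
  // 32 threads arranged 4x8, one 32-bit value per thread.
  auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<uint32_t>, float>{},
                                    Layout<Shape<_4,_8>>{},
                                    Layout<Shape<_1,_1>>{});
  print_latex(tiled_copy);  // pipe the output through pdflatex to render the figure
  return 0;
}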
|
||||
76
include/cute/atom/copy_traits.hpp
Normal file
76
include/cute/atom/copy_traits.hpp
Normal file
@ -0,0 +1,76 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class CopyOperation, class... CopyOpArgs>
|
||||
struct Copy_Traits
|
||||
{
|
||||
static_assert(sizeof(CopyOperation) == 0, "Copy_Traits not implemented for this Copy_Operation.");
|
||||
};
|
||||
|
||||
template <class S, class D>
|
||||
struct Copy_Traits<UniversalCopy<S,D>>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<DefaultCopy>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
} // end namespace cute
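New copy instructions are exposed to CuTe by specializing Copy_Traits with the three (thr,val) -> bit layouts shown above. A hedged sketch for a hypothetical single-thread 64-bit copy operation (MyCopy64 is not a real CUTLASS operation):

// Hypothetical operation plus its traits, mirroring the UniversalCopy specialization above.
struct MyCopy64 {
  using SRegisters = uint64_t[1];
  using DRegisters = uint64_t[1];
  CUTE_HOST_DEVICE static void copy(uint64_t const& src, uint64_t& dst) { dst = src; }
};

namespace cute {

template <>
struct Copy_Traits<MyCopy64> {
  // Logical thread id to thread idx (one-thread)
  using ThrID = Layout<_1>;
  // Map from (src-thr,src-val) to bit
  using SrcLayout = Layout<Shape<_1,Int<64>>>;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = Layout<Shape<_1,Int<64>>>;
  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

} // end namespace cute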
|
||||
143
include/cute/atom/copy_traits_sm75.hpp
Normal file
143
include/cute/atom/copy_traits_sm75.hpp
Normal file
@ -0,0 +1,143 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/copy_sm75.hpp>
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM75_U32x1_LDSM_N>
|
||||
{
|
||||
// Logical thread id to thread idx (warp)
|
||||
using ThrID = Layout<_32>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape <Shape < _8,_4>,_128>,
|
||||
Stride<Stride<_128,_0>, _1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape <_32,_32>,
|
||||
Stride<_32, _1>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = DstLayout;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM75_U32x2_LDSM_N>
|
||||
{
|
||||
// Logical thread id to thread idx (warp)
|
||||
using ThrID = Layout<_32>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape <Shape < _16,_2>,_128>,
|
||||
Stride<Stride<_128,_0>, _1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape <_32,Shape <_32, _2>>,
|
||||
Stride<_32,Stride< _1,_1024>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = DstLayout;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM75_U32x4_LDSM_N>
|
||||
{
|
||||
// Logical thread id to thread idx (warp)
|
||||
using ThrID = Layout<_32>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape < _32,_128>,
|
||||
Stride<_128, _1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape <_32,Shape <_32, _4>>,
|
||||
Stride<_32,Stride< _1,_1024>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = DstLayout;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM75_U16x2_LDSM_T>
|
||||
{
|
||||
// Logical thread id to thread idx (warp)
|
||||
using ThrID = Layout<_32>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape <Shape < _8,_4>,_128>,
|
||||
Stride<Stride<_128,_0>, _1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape <Shape < _4, _8>,Shape <_16, _2>>,
|
||||
Stride<Stride<_256,_16>,Stride< _1,_128>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = DstLayout;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM75_U16x4_LDSM_T>
|
||||
{
|
||||
// Logical thread id to thread idx (warp)
|
||||
using ThrID = Layout<_32>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape <Shape < _16,_2>,_128>,
|
||||
Stride<Stride<_128,_0>, _1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape <Shape < _4, _8>,Shape <_16, _2, _2>>,
|
||||
Stride<Stride<_256,_16>,Stride< _1,_128,_1024>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = DstLayout;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM75_U16x8_LDSM_T>
|
||||
{
|
||||
// Logical thread id to thread idx (warp)
|
||||
using ThrID = Layout<_32>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape < _32,_128>,
|
||||
Stride<_128, _1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape <Shape < _4, _8>,Shape <_16, _2, _4>>,
|
||||
Stride<Stride<_256,_16>,Stride< _1,_128,_1024>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = DstLayout;
|
||||
};
|
||||
|
||||
} // end namespace cute
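The bit layouts above record how much data each ldmatrix variant touches: SM75_U32x1_LDSM_N hands each of the 32 threads one 32-bit destination register (32 x 32 = 1024 bits, i.e. one 8x8 matrix of 16-bit elements), while the source mode (8,4) with a zero stride says only 8 distinct 128-bit shared-memory rows are addressed. A small compile-time check of those sizes, assuming the headers compile standalone:

#include <cute/atom/copy_traits_sm75.hpp>

using TraitsX1 = cute::Copy_Traits<cute::SM75_U32x1_LDSM_N>;

// Destination: 32 threads x 32 bits = 1024 bits in total.
static_assert(decltype(cute::size(typename TraitsX1::DstLayout{}))::value == 1024, "");
// Source: 32 logical threads x 128 bits of coverage ...
static_assert(decltype(cute::size(typename TraitsX1::SrcLayout{}))::value == 4096, "");
// ... but only 8 x 128 = 1024 distinct source bits, thanks to the stride-0 replication.
static_assert(decltype(cute::cosize(typename TraitsX1::SrcLayout{}))::value == 1024, "");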
|
||||
98
include/cute/atom/copy_traits_sm80.hpp
Normal file
98
include/cute/atom/copy_traits_sm80.hpp
Normal file
@ -0,0 +1,98 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cute/arch/copy_sm80.hpp>
#include <cute/atom/copy_traits.hpp>

#include <cute/layout.hpp>

namespace cute
{

template <class S, class D>
struct Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<S,D>>
{
  // Logical thread id to thread idx (one-thread)
  using ThrID = Layout<_1>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

template <class S, class D>
struct Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL<S,D>>
{
  // Logical thread id to thread idx (one-thread)
  using ThrID = Layout<_1>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element copy selector
template <class SrcTensor, class DstTensor>
CUTE_HOST_DEVICE constexpr
auto
select_elementwise_copy(SrcTensor const&, DstTensor const&)
{
  using SrcType = typename SrcTensor::value_type;
  using DstType = typename DstTensor::value_type;

#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
  if constexpr (is_gmem<SrcTensor>::value && is_smem<DstTensor>::value &&
                sizeof(SrcType) == sizeof(DstType) &&
                (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16))
  {
    return SM80_CP_ASYNC_CACHEALWAYS<SrcType,DstType>{};
  } else {
    return UniversalCopy<SrcType,DstType>{};
  }

  CUTE_GCC_UNREACHABLE;
#else
  return UniversalCopy<SrcType,DstType>{};
#endif
}

}
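As a quick orientation (this sketch is not part of the diff, and the thread/value shapes are illustrative assumptions), a cp.async traits specialization like the ones above is normally consumed by wrapping the architecture op in a Copy_Atom and tiling it over the thread block:

// Illustrative only: 32x8 threads, each moving one 128-bit (four-float) vector per copy step.
#include <cute/atom/copy_atom.hpp>
using namespace cute;

using GmemCopyAtom   = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, float>;
auto gmem_tiled_copy = make_tiled_copy(GmemCopyAtom{},
                                       Layout<Shape<_32,_8>>{},   // assumed thread layout
                                       Layout<Shape< _4,_1>>{});  // assumed values per thread

The SrcLayout/DstLayout maps in the traits above are what the atom uses to verify that a 128-bit vector is a legal unit for this instruction.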
include/cute/atom/copy_traits_sm90.hpp (new file, 132 lines)
@@ -0,0 +1,132 @@
#pragma once

#include <cute/arch/copy_sm90.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/atom/copy_traits_sm75.hpp>

#include <cute/layout.hpp>

namespace cute
{

template <>
struct Copy_Traits<SM90_U32x1_STSM_N>
{
  // Logical thread id to thread idx (warp)
  using ThrID = Layout<_32>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = typename Copy_Traits<SM75_U32x1_LDSM_N>::DstLayout;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = typename Copy_Traits<SM75_U32x1_LDSM_N>::SrcLayout;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

template <>
struct Copy_Traits<SM90_U32x2_STSM_N>
{
  // Logical thread id to thread idx (warp)
  using ThrID = Layout<_32>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = typename Copy_Traits<SM75_U32x2_LDSM_N>::DstLayout;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = typename Copy_Traits<SM75_U32x2_LDSM_N>::SrcLayout;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

template <>
struct Copy_Traits<SM90_U32x4_STSM_N>
{
  // Logical thread id to thread idx (warp)
  using ThrID = Layout<_32>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = typename Copy_Traits<SM75_U32x4_LDSM_N>::DstLayout;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = typename Copy_Traits<SM75_U32x4_LDSM_N>::SrcLayout;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

template <>
struct Copy_Traits<SM90_U16x2_STSM_T>
{
  // Logical thread id to thread idx (warp)
  using ThrID = Layout<_32>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = typename Copy_Traits<SM75_U16x2_LDSM_T>::DstLayout;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = typename Copy_Traits<SM75_U16x2_LDSM_T>::SrcLayout;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

template <>
struct Copy_Traits<SM90_U16x4_STSM_T>
{
  // Logical thread id to thread idx (warp)
  using ThrID = Layout<_32>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = typename Copy_Traits<SM75_U16x4_LDSM_T>::DstLayout;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = typename Copy_Traits<SM75_U16x4_LDSM_T>::SrcLayout;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

template <>
struct Copy_Traits<SM90_U16x8_STSM_T>
{
  // Logical thread id to thread idx (warp)
  using ThrID = Layout<_32>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = typename Copy_Traits<SM75_U16x8_LDSM_T>::DstLayout;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = typename Copy_Traits<SM75_U16x8_LDSM_T>::SrcLayout;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};

} // end namespace cute
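Note how each STSM specialization simply swaps the SrcLayout and DstLayout of the corresponding SM75 LDSM traits: stmatrix is the store-direction dual of ldmatrix, so the same (thr,val)-to-bit map applies with the roles of registers and shared memory exchanged. A minimal sketch of how one of these atoms might be used (the element type and the pairing with a tiled MMA are illustrative assumptions, not taken from the diff):

// Illustrative only: a warp-collective store atom, 4x32-bit registers per thread into smem.
#include <cute/atom/copy_atom.hpp>
using namespace cute;

using SmemStoreAtom = Copy_Atom<SM90_U32x4_STSM_N, half_t>;
// e.g. auto tiled_stsm = make_tiled_copy_C(SmemStoreAtom{}, tiled_mma);  // hypothetical epilogue pairing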
include/cute/atom/copy_traits_sm90_tma.hpp (new file, 795 lines)
@@ -0,0 +1,795 @@
#pragma once

#include <cuda.h>

#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/atom/copy_traits.hpp>

#include <cute/tensor.hpp>

namespace cute
{

//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD ///////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////

struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {};

// The executable SM90_TMA_LOAD with tma_desc and tma_mbar
|
||||
template <class NumBits>
|
||||
struct Copy_Traits<SM90_TMA_LOAD_OP, NumBits>
|
||||
{
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1,NumBits>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1,NumBits>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM90_TMA_LOAD arguments
|
||||
TmaDescriptor const& tma_desc_;
|
||||
uint64_t& tma_load_mbar_;
|
||||
|
||||
template <class Coord, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
copy_unpack_(void const* const dst_ptr,
|
||||
Coord const& src_coord, seq<Is...>) const
|
||||
{
|
||||
#if 0
|
||||
print("THR (%d,%d,%d) BLK (%d,%d,%d)\n",
|
||||
threadIdx.x, threadIdx.y, threadIdx.z,
|
||||
blockIdx.x, blockIdx.y, blockIdx.z);
|
||||
print(" TMA Coord "); print(src_coord); print("\n");
|
||||
print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_),
|
||||
uint64_t(tma_desc_.size1_),
|
||||
uint64_t(tma_desc_.size2_),
|
||||
uint64_t(tma_desc_.size3_))); print("\n");
|
||||
#endif
|
||||
|
||||
SM90_TMA_LOAD::copy(&tma_desc_,
|
||||
tma_load_mbar_,
|
||||
dst_ptr,
|
||||
get<Is>(src_coord)...);
|
||||
}
|
||||
|
||||
// This is the copy_unpack dispatch for this Copy_Traits
|
||||
// Src needs to be a gmem tensor with TmaCoordIterator .data()
|
||||
// Dst needs to be a smem tensor
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr
|
||||
void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst)
|
||||
{
|
||||
//static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor
|
||||
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_TMA_LOAD");
|
||||
|
||||
traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq<decltype(src.data().coord_)>{});
|
||||
}
|
||||
};
|
||||
|
||||
// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
template <class NumBits, class GmemStrides>
struct Copy_Traits<SM90_TMA_LOAD, NumBits, GmemStrides>
{
  using ThrID = Layout<_1>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = Layout<Shape<_1,NumBits>>;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = Layout<Shape<_1,NumBits>>;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;

  // SM90_TMA_LOAD arguments
  TmaDescriptor tma_desc_;
  GmemStrides g_stride_;

  // Return TmaDescriptor/TensorMap
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  // Construct an executable SM90_TMA_LOAD with tma_mbar
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_TMA_LOAD_OP, NumBits>
  with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const {
    // We accept multicast_mask here to keep the API for both atoms consistent
    // assert(multicast_mask == 0);
    (void) multicast_mask;
    return {tma_desc_, tma_mbar};
  }

  // Generate the TMA coord tensor
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    static_assert(is_congruent<decltype(g_shape), decltype(g_stride_)>::value);
    constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value;
    return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat<tma_rank>(Int<0>{}))),
                       g_shape,
                       g_stride_);
  }

  // Don't try to execute a copy with SM90_TMA_LOAD before calling .with()
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout> & dst) = delete;
};

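The split between this non-executable specialization and SM90_TMA_LOAD_OP above is the key usage pattern: make_tma_copy builds an atom that only carries the TMA descriptor, and the kernel binds the runtime arguments at the call site via .with(), which returns the executable traits (its copy_unpack is deleted here precisely so a plain copy cannot be issued by mistake). A compact sketch of the flow, following the Usage section further below (the tensor names are illustrative):

// Illustrative fragment: host side builds the descriptor-only atom,
// device side supplies the mbarrier when the copy is issued.
auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);   // Copy_Traits<SM90_TMA_LOAD, ...>: no mbar yet
// ... inside the kernel ...
// copy(tma.with(smem_mbar), tAgA, tAsA);                      // Copy_Traits<SM90_TMA_LOAD_OP, ...>: executable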
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {};
|
||||
|
||||
template <class NumBits>
|
||||
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBits>
|
||||
{
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1,NumBits>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1,NumBits>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM90_TMA_LOAD_MULTICAST arguments
|
||||
TmaDescriptor const& tma_desc_;
|
||||
uint64_t& tma_load_mbar_;
|
||||
uint16_t const& multicast_mask_;
|
||||
|
||||
template <class Coord, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
copy_unpack_(void const* const dst_ptr,
|
||||
Coord const& src_coord, seq<Is...>) const
|
||||
{
|
||||
#if 0
|
||||
print("THR (%d,%d,%d) BLK (%d,%d,%d)\n",
|
||||
threadIdx.x, threadIdx.y, threadIdx.z,
|
||||
blockIdx.x, blockIdx.y, blockIdx.z);
|
||||
print(" TMA Coord "); print(src_coord); print("\n");
|
||||
print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_),
|
||||
uint64_t(tma_desc_.size1_),
|
||||
uint64_t(tma_desc_.size2_),
|
||||
uint64_t(tma_desc_.size3_))); print("\n");
|
||||
#endif
|
||||
|
||||
SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_,
|
||||
tma_load_mbar_,
|
||||
multicast_mask_,
|
||||
dst_ptr,
|
||||
get<Is>(src_coord)...);
|
||||
}
|
||||
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr
|
||||
void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst)
|
||||
{
|
||||
//static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor
|
||||
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST");
|
||||
|
||||
traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq<decltype(src.data().coord_)>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class NumBits, class GmemStrides>
|
||||
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBits, GmemStrides>
|
||||
{
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1,NumBits>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1,NumBits>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM90_TMA_LOAD_MULTICAST arguments
|
||||
TmaDescriptor tma_desc_;
|
||||
GmemStrides g_stride_;
|
||||
|
||||
// Return TmaDescriptor/TensorMap
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TmaDescriptor const*
|
||||
get_tma_descriptor() const {
|
||||
return &tma_desc_;
|
||||
}
|
||||
|
||||
// Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBits>
|
||||
with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
|
||||
return {tma_desc_, tma_load_mbar, multicast_mask};
|
||||
}
|
||||
|
||||
// Generate the TMA coord tensor
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_tma_tensor(GShape const& g_shape) const {
|
||||
static_assert(is_congruent<decltype(g_shape), decltype(g_stride_)>::value);
|
||||
constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value;
|
||||
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat<tma_rank>(Int<0>{}))),
|
||||
g_shape,
|
||||
g_stride_);
|
||||
}
|
||||
|
||||
// Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with()
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) = delete;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////// TMA_STORE //////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The executable SM90_TMA_STORE with tma_desc
|
||||
template <class NumBits, class GmemStrides>
|
||||
struct Copy_Traits<SM90_TMA_STORE, NumBits, GmemStrides>
|
||||
{
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1,NumBits>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1,NumBits>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM90_TMA_STORE arguments
|
||||
TmaDescriptor tma_desc_;
|
||||
GmemStrides g_stride_;
|
||||
|
||||
// Generate the TMA coord tensor
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_tma_tensor(GShape const& g_shape) const {
|
||||
static_assert(is_congruent<decltype(g_shape), decltype(g_stride_)>::value);
|
||||
constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value;
|
||||
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat<tma_rank>(Int<0>{}))),
|
||||
g_shape,
|
||||
g_stride_);
|
||||
}
|
||||
|
||||
template <class Coord, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
copy_unpack_(void const* const src_ptr,
|
||||
Coord const& dst_coord, seq<Is...>) const
|
||||
{
|
||||
#if 0
|
||||
print("THR (%d,%d,%d) BLK (%d,%d,%d)\n",
|
||||
threadIdx.x, threadIdx.y, threadIdx.z,
|
||||
blockIdx.x, blockIdx.y, blockIdx.z);
|
||||
print(" TMA Coord "); print(dst_coord); print("\n");
|
||||
print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_),
|
||||
uint64_t(tma_desc_.size1_),
|
||||
uint64_t(tma_desc_.size2_),
|
||||
uint64_t(tma_desc_.size3_))); print("\n");
|
||||
#endif
|
||||
|
||||
SM90_TMA_STORE::copy(&tma_desc_,
|
||||
src_ptr,
|
||||
get<Is>(dst_coord)...);
|
||||
}
|
||||
|
||||
// This is the copy_unpack dispatch for this Copy_Traits
|
||||
// Src needs to be a smem tensor
|
||||
// Dst needs to be a gmem tensor with TmaCoordIterator .data()
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr
|
||||
void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst)
|
||||
{
|
||||
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
|
||||
//static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor
|
||||
|
||||
traits.copy_unpack_(src.data().get(), dst.data().coord_, tuple_seq<decltype(dst.data().coord_)>{});
|
||||
}
|
||||
};
|
||||
|
||||
//
// MAKE_TMA_COPY and related
//

template <int B, int M, int S, class Offset, class SLayout>
TMA::SmemSwizzleBits
get_tma_swizzle_bits(ComposedLayout<Swizzle<B,M,S>,Offset,SLayout>)
{
  static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle.");
  static_assert(S == 3, "Unsupported layout swizzle");

  switch (B) {
    default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3. Unsupported layout swizzle.");
    case 3: return TMA::SmemSwizzleBits::B128;
    case 2: return TMA::SmemSwizzleBits::B64;
    case 1: return TMA::SmemSwizzleBits::B32;
    case 0: return TMA::SmemSwizzleBits::DISABLE;
  }
}

template <class Shape, class Stride>
TMA::SmemSwizzleBits
get_tma_swizzle_bits(Layout<Shape,Stride>)
{
  return TMA::SmemSwizzleBits::DISABLE;
}

template <int B, int M, int S, class Offset, class SLayout>
auto
get_nonswizzle_layout(ComposedLayout<Swizzle<B,M,S>,Offset,SLayout> const& slayout)
{
  return slayout.layout_fn();
}

template <class Shape, class Stride>
auto
get_nonswizzle_layout(Layout<Shape,Stride> const& slayout)
{
  return slayout;
}

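The mapping above is direct because, with a 16-byte base unit (M == 4), a Swizzle<B,4,3> pattern repeats every 16 * 2^B bytes, which is exactly the granularity the TMA swizzle modes are named after. Summarizing the switch for orientation:

// Swizzle<B,4,3> on the smem layout  ->  TMA swizzle mode
//   B == 0   (no swizzle)            ->  SmemSwizzleBits::DISABLE
//   B == 1   (32B pattern)           ->  SmemSwizzleBits::B32
//   B == 2   (64B pattern)           ->  SmemSwizzleBits::B64
//   B == 3   (128B pattern)          ->  SmemSwizzleBits::B128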
/** Make a CuTe CTA-collective TiledCopy for a TMA operation.
 *
 * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE
 * @param gtensor The GMEM Tensor to be involved in the TMA.
 * @param slayout The SMEM Layout to be involved in the TMA.
 * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with.
 *                 This is often the blk_shape that is used to tile the GMEM for CTAs:
 *                 local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor
 * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16
 *                 defining the multicast size (used to further partition the SMEM)
 *                 Else, static-1
 *
 * This code attempts to maximize the TMA box size. It does this by tracing
 * the SMEM "vector" -- the inverse of the smem layout -- to find the largest
 * contiguous array of smem that can be written to/from global memory given
 * the constraints that the TMA instruction imposes.
 *
 * This is accomplished by assigning "basis" strides to the GMEM to track which
 * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according
 * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc.
 *
 * Examples:
     using T = float;
     T* gptr = nullptr;

     {
     // Simple 2D
     Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{});                  // K-Major GMEM
     auto slayout   = make_layout(make_shape(_64{}, _32{}), GenRowMajor{});                     // K-Major SMEM
     auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
     }

     {
     // GMMA 2D
     Tensor gtensor = make_tensor(gptr, make_shape(1024, 256));                                 // MN-Major GMEM
     auto slayout   = tile_to_shape(GMMA::Layout_MN_SW128_Atom<T>{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM
     auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
     }

     {
     // 3D
     Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM
     auto slayout   = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{}));     // SMEM w/ same major-mode
     auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
     }

     {
     // cuTENSOR 4D
     Tensor gtensor = make_tensor(gptr, make_shape(make_shape(32,40),make_shape(make_shape(8,8),656))); // GMEM
     auto cta_tile  = make_shape(_128{},make_shape(_32{},_2{}));    // GMEM Tiling:
                                                                    //   Take 128-elem from m: m0 must divide 128,
                                                                    //                         m-last may be predicated
                                                                    //   Take 32-elem from k0, 2-elem from k1
     auto slayout = make_layout(cta_tile);                          // Col-Major SMEM
     auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{});
     }
 *
 * Check the TMA box size and desc:
     print("TMA Box size:  "); print(typename decltype(tma)::Tiler_MN{}); print("\n");
     print("TMA desc     : "); print(tma.tma_desc_); print("\n");
 *
 * Usage:
     Tensor mA = tma_a.get_tma_tensor(make_shape(M,N));          // (M,N) TMA coord tensor
     Tensor gA = local_tile(mA, cta_tile, cta_coord);            // (BLK_M,BLK_N) TMA coord tensor for this CTA
     Tensor sA = make_tensor(make_smem_ptr<T>(sptr), slayout);   // (BLK_M,BLK_N) SMEM tensor

     auto cta_tma = tma.get_slice(cta_idx_in_cluster);           // Slice for multicast partitioning
     Tensor tAgA = cta_tma.partition_S(gA);                      // Partition for src
     Tensor tAsA = cta_tma.partition_D(sA);                      // Partition for dst

     copy(tma.with(barrier, mcast_mask), tAgA, tAsA);            // copy with supporting TMA params
 */
template <class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class CTA_Tile,
|
||||
class Cluster_Size>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_tma_copy(CopyOp,
|
||||
Tensor<GEngine,GLayout> const& gtensor,
|
||||
SLayout const& slayout,
|
||||
CTA_Tile const& cta_tile,
|
||||
Cluster_Size const& cluster_size)
|
||||
{
|
||||
static_assert((std::is_same<CopyOp, SM90_TMA_LOAD>::value && is_constant<1, Cluster_Size>::value) ||
|
||||
(std::is_same<CopyOp, SM90_TMA_LOAD_MULTICAST>::value) ||
|
||||
(std::is_same<CopyOp, SM90_TMA_STORE>::value && is_constant<1, Cluster_Size>::value));
|
||||
|
||||
using T = typename Tensor<GEngine,GLayout>::value_type;
|
||||
|
||||
//
|
||||
// TMA parameter checking
|
||||
//
|
||||
|
||||
auto flat_glayout = flatten(gtensor.layout());
|
||||
|
||||
CUTE_STATIC_ASSERT_V(rank(flatten(cta_tile)) <= Int<5>{},
|
||||
"CTA_Tile cannot have more than five modes, TMA arch restriction.");
|
||||
CUTE_STATIC_ASSERT_V(rank(flat_glayout) <= Int<5>{} || rank(flatten(cta_tile)) <= Int<4>{},
|
||||
"If GTensor has more than five modes, then CTA_Tile cannot have more than four modes. TMA multimode.");
|
||||
CUTE_STATIC_ASSERT_V(compatible(product_each(shape(slayout)), shape(cta_tile)),
|
||||
"CTA_Tile must be compatible with SLayout.");
|
||||
CUTE_STATIC_ASSERT_V(is_integral<Cluster_Size>{} && has_single_bit(cluster_size) && cluster_size <= Int<16>{},
|
||||
"Expecting a pow2 integral Cluster_Size leq 16.");
|
||||
CUTE_STATIC_ASSERT_V(size(slayout) % cluster_size == Int<0>{},
|
||||
"ClusterShape must divide domain size of slayout.");
|
||||
|
||||
//
|
||||
// TMA slayout manipulation
|
||||
//
|
||||
|
||||
auto tma_multimode = rank(flat_glayout) > Int<5>{};
|
||||
|
||||
// Invert the smem to get the largest contiguous vector in the smem layout
|
||||
auto inv_smem_layout = right_inverse(get_nonswizzle_layout(slayout));
|
||||
// trunc_smem_idx -> trunc_smem_coord
|
||||
|
||||
// Map from smem idx to a gmem mode
|
||||
auto sidx_to_gmode = flatten(composition(make_identity_layout(cta_tile), inv_smem_layout));
|
||||
|
||||
// Truncate any incompatibilities
|
||||
auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e){
|
||||
[[maybe_unused]] auto v = basis_value(e);
|
||||
return not is_constant<1,decltype(v)>{};
|
||||
});
|
||||
static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA.");
|
||||
constexpr int smem_tma_rank = cute::min(int(smem_rank), (tma_multimode ? 4 : 5));
|
||||
|
||||
// Keep only the static-1 basis modes into gmem
|
||||
auto sidx_to_gmode_cluster_trunc = take<0,smem_tma_rank>(sidx_to_gmode);
|
||||
// Keep only the portion each multicast CTA will be responsible for
|
||||
auto sidx_to_gmode_cta_trunc = composition(sidx_to_gmode_cluster_trunc, shape_div(size(sidx_to_gmode_cluster_trunc), cluster_size));
|
||||
|
||||
//
|
||||
// TMA gtensor manipulation
|
||||
//
|
||||
|
||||
// Generate a TupleBasis for the gtensor
|
||||
auto flat_gbasis = make_basis_like(shape(flat_glayout));
|
||||
|
||||
// Fold the flat_gbasis into the glayout
|
||||
auto glayout_basis = make_layout(shape(gtensor),
|
||||
stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis),
|
||||
make_layout(repeat_like(shape(gtensor), Int<2>{})))));
|
||||
|
||||
// Tile the modes of gtensor with cta_tile
|
||||
auto cta_glayout_basis = composition(glayout_basis, cta_tile);
|
||||
|
||||
// Check that the cta_tile selects modes from gtensor properly
|
||||
for_each(flatten(stride(cta_glayout_basis)), [](auto d) {
|
||||
static_assert(is_constant<1, decltype(d.value())>::value,
|
||||
"CTA_Tile does not faithfully partition the GMEM, it should select the number of elements from each mode of glayout.");
|
||||
});
|
||||
|
||||
// Tile the modes of gtensor again with the truncated cta_tile o inv_smem_layout
|
||||
auto tma_layout_cta_trunc = flatten(composition(glayout_basis, sidx_to_gmode_cta_trunc));
|
||||
|
||||
// Append any missing basis on the end as size-1 modes b/c they got truncated
|
||||
auto missing_basis = fold(stride(tma_layout_cta_trunc), flat_gbasis, [](auto init, auto e){
|
||||
auto k = find(init, e);
|
||||
return remove<k>(init);
|
||||
});
|
||||
|
||||
// The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode
|
||||
auto tma_layout_cta = flatten(make_layout(tma_layout_cta_trunc,
|
||||
make_layout(repeat<rank(missing_basis)>(Int<1>{}), missing_basis)));
|
||||
|
||||
#if 0
|
||||
print("g_layout : "); print(gtensor.layout()); print("\n");
|
||||
print("s_layout : "); print(slayout); print("\n");
|
||||
print("cta_tile : "); print(cta_tile); print("\n");
|
||||
print("cluster_size : "); print(cluster_size); print("\n");
|
||||
print("flat_gbasis : "); print(flat_gbasis); print("\n");
|
||||
print("cta_glayout : "); print(cta_glayout_basis); print("\n");
|
||||
print("inv_smem : "); print(inv_smem_layout); print("\n");
|
||||
print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n");
|
||||
print("missing_b : "); print(missing_basis); print("\n");
|
||||
print("tma_layout_cta: "); print(tma_layout_cta); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// TMA gmem desc info
|
||||
//
|
||||
|
||||
constexpr int TmaRANK = cute::min(rank(flat_glayout), 5);
|
||||
void* gmem_address = (void*) gtensor.data();
|
||||
|
||||
cute::array<cuuint64_t, 5> gmem_prob_shape = {1,1,1,1,1};
|
||||
cute::array<cuuint64_t, 5> gmem_prob_stride = {0,0,0,0,0};
|
||||
for_each(make_seq<rank(tma_layout_cta)>{}, [&](auto i) {
|
||||
// NOTE : WAR g++-7.3.5, let it deduce e rather than fuse with below
|
||||
auto e = stride<i>(tma_layout_cta);
|
||||
constexpr int j = decltype(e.mode())::value;
|
||||
constexpr int tma_i = i < 5 ? i : 4;
|
||||
|
||||
// Problem stride
|
||||
uint64_t stride_j = stride<j>(flat_glayout) * sizeof(T);
|
||||
uint64_t old_stride = gmem_prob_stride[tma_i];
|
||||
gmem_prob_stride[tma_i] = gcd(gmem_prob_stride[tma_i], stride_j);
|
||||
|
||||
// Problem shape
|
||||
uint64_t shape_j = shape<j>(flat_glayout);
|
||||
if (gmem_prob_stride[tma_i] != 0) {
|
||||
// We're "resetting" this TMA mode and using it as a "multimode"
|
||||
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
|
||||
gmem_prob_shape[tma_i] = (gmem_prob_shape[tma_i]-1) * (old_stride / gmem_prob_stride[tma_i])
|
||||
+ (shape_j-1) * (stride_j / gmem_prob_stride[tma_i])
|
||||
+ 1;
|
||||
} else {
|
||||
gmem_prob_shape[tma_i] = shape_j;
|
||||
}
|
||||
});
|
||||
|
||||
assert((reinterpret_cast<uint64_t>(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned
|
||||
|
||||
assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32
|
||||
assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32
|
||||
assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32
|
||||
assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32
|
||||
assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32
|
||||
|
||||
assert((gmem_prob_stride[0]) == sizeof(T)); // First stride is implicitly 1
|
||||
assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40
|
||||
assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
|
||||
assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40
|
||||
assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
|
||||
assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40
|
||||
assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
|
||||
assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40
|
||||
assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b)
|
||||
|
||||
//
|
||||
// TMA smem desc info
|
||||
//
|
||||
|
||||
// TMA smem box size
|
||||
cute::array<cuuint32_t, 5> smem_box_shape = {1,1,1,1,1};
|
||||
for_each(make_seq<rank(tma_layout_cta)>{}, [&](auto i) {
|
||||
uint32_t shape_i = shape<i>(tma_layout_cta);
|
||||
constexpr int tma_i = i < 5 ? i : 4;
|
||||
if (tma_multimode && tma_i == 4) {
|
||||
// We're "reusing" this TMA mode and using it as a "multimode"
|
||||
smem_box_shape[tma_i] = 1;
|
||||
} else {
|
||||
smem_box_shape[tma_i] = shape_i;
|
||||
}
|
||||
});
|
||||
|
||||
// TMA smem mode strides
|
||||
[[maybe_unused]] cute::array<cuuint32_t, 5> smem_box_stride = {1,1,1,1,1};
|
||||
|
||||
assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8
|
||||
assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8
|
||||
assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8
|
||||
assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1
|
||||
assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8
|
||||
|
||||
assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1
|
||||
assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3
|
||||
assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1
|
||||
assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3
|
||||
assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1
|
||||
assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3
|
||||
assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1
|
||||
assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3
|
||||
assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1
|
||||
assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3
|
||||
|
||||
//
|
||||
// Construct the descriptor
|
||||
//
|
||||
|
||||
TmaDescriptor tma_desc = {0};
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
|
||||
//
|
||||
// TMA general info
|
||||
//
|
||||
|
||||
cuuint32_t tma_dim = TmaRANK;
|
||||
CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType<T>();
|
||||
CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
|
||||
CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;
|
||||
CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
|
||||
|
||||
// TMA smem swizzle type
|
||||
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(slayout));
|
||||
|
||||
CUresult result = cuTensorMapEncodeTiled(
|
||||
&tma_desc,
|
||||
tma_format,
|
||||
tma_dim,
|
||||
gmem_address,
|
||||
gmem_prob_shape.data(),
|
||||
gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1
|
||||
smem_box_shape.data(),
|
||||
smem_box_stride.data(),
|
||||
tma_interleave,
|
||||
smem_swizzle,
|
||||
tma_l2Promotion,
|
||||
tma_oobFill);
|
||||
|
||||
if (result != CUDA_SUCCESS) {
|
||||
std::cerr << "TMA Desc Addr: " << &tma_desc
|
||||
<< "\nformat " << tma_format
|
||||
<< "\ndim " << tma_dim
|
||||
<< "\ngmem_address " << gmem_address
|
||||
<< "\nglobalDim " << gmem_prob_shape
|
||||
<< "\nglobalStrides " << gmem_prob_stride
|
||||
<< "\nboxDim " << smem_box_shape
|
||||
<< "\nelementStrides " << smem_box_stride
|
||||
<< "\ninterleave " << tma_interleave
|
||||
<< "\nswizzle " << smem_swizzle
|
||||
<< "\nl2Promotion " << tma_l2Promotion
|
||||
<< "\noobFill " << tma_oobFill << std::endl;
|
||||
std::cerr << "Error: Failed to intialize the TMA descriptor " << result << std::endl;
|
||||
assert(false);
|
||||
}
|
||||
#endif // (__CUDACC_VER_MAJOR__ >= 12)
|
||||
|
||||
//
|
||||
// Construct the Copy_Traits
|
||||
//
|
||||
|
||||
// Finally, get the inverse permutation of the E<i> bases for the mocked gmem stride
|
||||
auto gmem_stride_bases_flat = transform(make_seq<rank(tma_layout_cta)>{}, [&](auto i) {
|
||||
auto k = find(stride(tma_layout_cta), E<i>{});
|
||||
// NOTE: gcc 7.3.5 WAR -- avoid if constexpr
|
||||
int32_t tma_coord_stride = int32_t(stride<i>(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16));
|
||||
return conditional_return(tma_multimode && (k >= Int<4>{}),
|
||||
E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride
|
||||
E<k>{});
|
||||
});
|
||||
|
||||
// Give it the profile of gtensor and fold it
|
||||
auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat),
|
||||
make_layout(repeat_like(shape(gtensor), Int<2>{}))));
|
||||
|
||||
constexpr int num_bits = size(sidx_to_gmode_cta_trunc) * sizeof(T) * 8;
|
||||
using Traits = Copy_Traits<CopyOp, Int<num_bits>, decltype(gmem_stride_bases)>;
|
||||
|
||||
#if 0
|
||||
print("num_bits : "); print(num_bits); print("\n");
|
||||
print("g_stride_bases: "); print(gmem_stride_bases); print("\n");
|
||||
#endif
|
||||
|
||||
//
|
||||
// Construct the TiledCopy
|
||||
//
|
||||
|
||||
// The ThrVal layout for 1 TMA instruction within cta_tile
|
||||
auto layout_tv_1 = composition(inv_smem_layout, make_layout(make_shape(cluster_size, size(sidx_to_gmode_cta_trunc)), GenRowMajor{}));
|
||||
// The ThrVal layout for N TMA instructions within cta_tile
|
||||
auto layout_tv = tile_to_shape(layout_tv_1, make_shape(cluster_size, size(cta_tile)/cluster_size));
|
||||
|
||||
#if 0
|
||||
print("layout_tv : "); print(layout_tv); print("\n");
|
||||
#endif
|
||||
|
||||
return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases};
|
||||
}
|
||||
|
||||
// Explicit defaulting
template <class CopyOp,
          class GEngine, class GLayout,
          class SLayout>
CUTE_HOST
auto
make_tma_copy(CopyOp                  const& copy_op,
              Tensor<GEngine,GLayout> const& gtensor,
              SLayout                 const& slayout)
{
  return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{});
}

template <class CopyOp,
          class GEngine, class GLayout,
          class SLayout,
          class Cluster_Size>
CUTE_HOST
auto
make_tma_copy(CopyOp                  const& copy_op,
              Tensor<GEngine,GLayout> const& gtensor,
              SLayout                 const& slayout,
              Cluster_Size            const& cluster_size)
{
  return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size);
}

} // end namespace cute
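These overloads make the common case a one-liner: omitting cta_tile defaults it to the SMEM layout's per-mode shape via product_each, and omitting cluster_size defaults to Int<1>{} (no multicast). A small host-side sketch, following the documentation examples above (the element type and extents are illustrative assumptions):

// Illustrative only: default-tiled, non-multicast TMA load atom.
using T = float;
T* gptr = nullptr;                                                         // assumed device pointer
Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{});  // K-major GMEM
auto slayout   = make_layout(make_shape(_64{}, _32{}), GenRowMajor{});     // K-major SMEM tile
auto tma       = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout);
// equivalent to make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, product_each(shape(slayout)), Int<1>{})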
include/cute/atom/mma_atom.hpp (new file, 1081 lines)
File diff suppressed because it is too large.
include/cute/atom/mma_traits.hpp (new file, 70 lines)
@@ -0,0 +1,70 @@
#pragma once

#include <cute/arch/mma.hpp>

#include <cute/layout.hpp>

namespace cute
{

template <class MMAOperation, class... MMAOpArgs>
struct MMA_Traits
{
  static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation.");
};

template <class D, class A, class B, class C>
struct MMA_Traits<UniversalFMA<D,A,B,C>>
{
  using ElementDVal = D;
  using ElementAVal = A;
  using ElementBVal = B;
  using ElementCVal = C;

  // Logical shape of the MMA
  using Shape_MNK = Shape<_1,_1,_1>;

  // Logical thread id (tid) -> tidx
  using ThrID = Layout<_1>;

  // (Logical thread id (tid), Logical value id (vid)) -> coord

  // (tid,vid) -> (m,k)
  using ALayout = Layout<Shape<_1,_1>>;
  // (tid,vid) -> (n,k)
  using BLayout = Layout<Shape<_1,_1>>;
  // (tid,vid) -> (m,n)
  using CLayout = Layout<Shape<_1,_1>>;
};

} // namespace cute
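MMA_Traits is the extension point that the MMA_Atom and TiledMMA machinery reads: every architecture op supplies element types, the logical Shape_MNK of one instruction, a ThrID map, and the (tid,vid)->coordinate layouts for A, B, and C. UniversalFMA is the scalar fallback, a 1x1x1 "instruction" executed by a single thread. A minimal sketch of how such a traits specialization is consumed (the 16x16x1 thread tiling is an illustrative assumption):

// Illustrative only: 256 threads, each performing one scalar FMA per MMA step.
#include <cute/atom/mma_atom.hpp>
using namespace cute;

auto tiled_mma = make_tiled_mma(UniversalFMA<float,float,float,float>{},
                                Layout<Shape<_16,_16,_1>>{});   // assumed (M,N,K) thread layout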
include/cute/atom/mma_traits_sm61.hpp (new file, 73 lines)
@@ -0,0 +1,73 @@
#pragma once

#include <cute/arch/mma_sm61.hpp>

#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>

namespace cute
{

template <>
struct MMA_Traits<SM61_DP4A>
{
  using ElementDVal = int32_t;
  using ElementAVal = int8_t;
  using ElementBVal = int8_t;
  using ElementCVal = int32_t;

  using Shape_MNK = Shape<_1,_1,_4>;
  using ThrID   = Layout<_1>;
  using ALayout = Layout<Shape<_1,_4>>;
  using BLayout = Layout<Shape<_1,_4>>;
  using CLayout = Layout<Shape<_1,_1>>;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM61_DP2A>
{
  using ElementDVal = int32_t;
  using ElementAVal = int16_t;
  using ElementBVal = int16_t;
  using ElementCVal = int32_t;

  using Shape_MNK = Shape<_1,_1,_2>;
  using ThrID   = Layout<_1>;
  using ALayout = Layout<Shape<_1,_2>>;
  using BLayout = Layout<Shape<_1,_2>>;
  using CLayout = Layout<Shape<_1,_1>>;
};

} // namespace cute
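Read Shape_MNK = (1,1,4) with ThrID = Layout<_1> as: one thread, one output element, a K=4 reduction per instruction. In other words, SM61_DP4A models int32 accumulation of a dot product over four packed int8 pairs. A scalar equivalence sketch (plain C++, purely illustrative):

// Illustrative only: what one DP4A "MMA" computes, written out element-wise.
#include <cstdint>

int32_t dp4a_reference(int32_t c, const int8_t a[4], const int8_t b[4]) {
  for (int k = 0; k < 4; ++k) {
    c += int32_t(a[k]) * int32_t(b[k]);   // K=4 reduction into the int32 accumulator
  }
  return c;
}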
include/cute/atom/mma_traits_sm70.hpp (new file, 198 lines)
@@ -0,0 +1,198 @@
#pragma once

#include <cute/arch/mma_sm70.hpp>

#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>

namespace cute
{

namespace {

// Logical thread id to thread idx (quadpair)
using SM70_QuadPair = Layout<Shape <_4, _2>,
                             Stride<_1,_16>>;
// (T8,V4) -> (M8,K4)
using SM70_8x4_Row  = Layout<Shape <_8,_4>,
                             Stride<_1,_8>>;
// (T8,V4) -> (M8,K4)
using SM70_8x4_Col  = Layout<Shape <Shape <_4,_2>,_4>,
                             Stride<Stride<_8,_4>,_1>>;
// (T8,V8) -> (M8,N8)
using SM70_8x8_16b  = Layout<Shape <_8,_8>,
                             Stride<_1,_8>>;
// (T8,V8) -> (M8,N8)
using SM70_8x8_32b  = Layout<Shape <Shape <_2, _2,_2>,Shape <_2,_2, _2>>,
                             Stride<Stride<_1,_16,_4>,Stride<_8,_2,_32>>>;

}

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TN>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Row;
  using BLayout = SM70_8x4_Row;
  using CLayout = SM70_8x8_16b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NT>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Col;
  using BLayout = SM70_8x4_Col;
  using CLayout = SM70_8x8_16b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_NN>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Col;
  using BLayout = SM70_8x4_Row;
  using CLayout = SM70_8x8_16b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F16F16F16F16_TT>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Row;
  using BLayout = SM70_8x4_Col;
  using CLayout = SM70_8x8_16b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TN>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Row;
  using BLayout = SM70_8x4_Row;
  using CLayout = SM70_8x8_32b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Col;
  using BLayout = SM70_8x4_Col;
  using CLayout = SM70_8x8_32b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NN>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Col;
  using BLayout = SM70_8x4_Row;
  using CLayout = SM70_8x8_32b;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_TT>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Row;
  using BLayout = SM70_8x4_Col;
  using CLayout = SM70_8x8_32b;
};

///////////////////////////////////////////////////////////////////////////////
} // namespace cute
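The SM70_QuadPair ThrID layout is worth pausing on: Shape (4,2) with Stride (1,16) maps the eight logical threads of one Volta mma.sync quad pair onto physical lanes {0,1,2,3,16,17,18,19}, i.e. a quad from each half of the warp. A tiny host-side check (illustrative only):

// Illustrative only: enumerate the lane ids selected by the quad-pair ThrID map.
#include <cstdio>
#include <cute/layout.hpp>
using namespace cute;

int main() {
  auto thr_id = Layout<Shape <_4, _2>, Stride<_1,_16>>{};          // SM70_QuadPair
  for (int t = 0; t < size(thr_id); ++t) {
    printf("logical thr %d -> lane %d\n", t, int(thr_id(t)));      // prints 0,1,2,3,16,17,18,19
  }
  return 0;
}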
include/cute/atom/mma_traits_sm75.hpp (new file, 81 lines)
@@ -0,0 +1,81 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
#pragma once

#include <cute/arch/mma_sm75.hpp>

#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>

namespace cute
{

template <>
struct MMA_Traits<SM75_16x8x8_F32F16F16F32_TN>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;

  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID     = Layout<_32>;
  using ALayout   = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
                           Stride<Stride<_32,_2>,Stride<_16,_1>>>;
  using BLayout   = Layout<Shape <Shape < _4,_8>,_2>,
                           Stride<Stride<_16,_1>,_8>>;
  using CLayout   = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
                           Stride<Stride<_32,_2>,Stride<_16,_1>>>;
};

///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM75_8x8x16_S32S8S8S32_TN>
{
  using ElementDVal = int32_t;
  using ElementAVal = int8_t;
  using ElementBVal = int8_t;
  using ElementCVal = int32_t;

  using Shape_MNK = Shape<_8,_8,_16>;
  using ThrID     = Layout<_32>;
  using ALayout   = Layout<Shape <Shape < _4,_8>,_4>,
                           Stride<Stride<_32,_1>,_8>>;
  using BLayout   = Layout<Shape <Shape < _4,_8>,_4>,
                           Stride<Stride<_32,_1>,_8>>;
  using CLayout   = Layout<Shape <Shape < _4,_8>,_2>,
                           Stride<Stride<_16,_1>,_8>>;
};

///////////////////////////////////////////////////////////////////////////////

} // namespace cute
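As a quick sanity check on the thread-value layouts above (an illustrative sketch, not part of the diff): each layout enumerates exactly one slot per element of the atom's tile, so its size must match the corresponding tile extent of `Shape_MNK`.

```cpp
using namespace cute;

using T = MMA_Traits<SM75_8x8x16_S32S8S8S32_TN>;

// 32 threads x 4 int8 values each = the full 8x16 A tile; similarly for B and C.
static_assert(size(typename T::ALayout{}) == 8 * 16);
static_assert(size(typename T::BLayout{}) == 8 * 16);
static_assert(size(typename T::CLayout{}) == 8 *  8);
```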

446  include/cute/atom/mma_traits_sm80.hpp  Normal file
@@ -0,0 +1,446 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
#pragma once

#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_traits.hpp>

#include <cute/layout.hpp>

#include <cute/numeric/integer_subbyte.hpp>

#include <cutlass/numeric_types.h>

namespace cute
{

namespace {

// (T32,V1) -> (M8,N8)
using SM80_8x4      = Layout<Shape <Shape < _4,_8>,_1>,
                             Stride<Stride< _8,_1>,_0>>;
// (T32,V2) -> (M8,N8)
using SM80_8x8_Row  = Layout<Shape <Shape < _4,_8>,_2>,
                             Stride<Stride<_16,_1>,_8>>;
// (T32,V4) -> (M8,N16)
using SM80_8x16_Row = Layout<Shape <Shape < _4,_8>,_4>,
                             Stride<Stride<_32,_1>,_8>>;
// (T32,V4) -> (M16,N8)
using SM80_16x8_Row = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
                             Stride<Stride<_32,_1>,Stride<_16,_8>>>;

}
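These helpers are "TV layouts": they map a `(thread, value)` coordinate to a linear offset into the atom's column-major tile, i.e. offset = m + M*n. An illustrative check of `SM80_16x8_Row` (hypothetical snippet that would sit inside this header, not part of the diff):

```cpp
// SM80_16x8_Row maps (thread in [0,32), value in [0,4)) into a column-major
// 16x8 tile: offset = m + 16*n.
//
//   thread 0, value 1 -> offset 16 -> (m,n) = (0,1)
//   thread 4, value 0 -> offset  1 -> (m,n) = (1,0)
static_assert(SM80_16x8_Row{}(0, 1) == 16);
static_assert(SM80_16x8_Row{}(4, 0) ==  1);
```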

///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp16 = fp16 * fp16 + fp16 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;

  using Shape_MNK = Shape<_16,_8,_8>;
  using ThrID     = Layout<_32>;
  using ALayout   = SM80_16x8_Row;
  using BLayout   = SM80_8x8_Row;
  using CLayout   = SM80_16x8_Row;
};

template <>
struct MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
  using ElementDVal = half_t;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = half_t;

  using Shape_MNK = Shape<_16,_8,_16>;
  using ThrID     = Layout<_32>;
  using ALayout   = Layout<Shape <Shape < _4,_8>,Shape < _2,_2, _2>>,
                           Stride<Stride<_32,_1>,Stride<_16,_8,_128>>>;
  using BLayout   = Layout<Shape <Shape < _4,_8>,Shape <_2, _2>>,
                           Stride<Stride<_16,_1>,Stride<_8,_64>>>;
  using CLayout   = SM80_16x8_Row;
};

///////////////////////////////////////////////////////////////////////////////
//////////////////////// fp32 = fp16 * fp16 + fp32 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM80_16x8x8_F32F16F16F32_TN>
     : MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;
};

template <>
struct MMA_Traits<SM80_16x8x16_F32F16F16F32_TN>
     : MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
{
  using ElementDVal = float;
  using ElementAVal = half_t;
  using ElementBVal = half_t;
  using ElementCVal = float;
};
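A sketch of how these SM80 traits are typically consumed when building a warp-level GEMM mainloop. This is illustrative only: it assumes the standard CuTe `make_tiled_mma` API from `cute/atom/mma_atom.hpp`, and the 2x2x1 atom arrangement is an arbitrary example.

```cpp
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// Four warps (a 2x2x1 arrangement of 16x8x16 atoms) cooperate on one tile.
using MmaOp = SM80_16x8x16_F32F16F16F32_TN;

CUTE_HOST_DEVICE auto make_example_tiled_mma()
{
  auto tiled_mma = make_tiled_mma(MmaOp{}, Layout<Shape<_2,_2,_1>>{});
  CUTE_STATIC_ASSERT_V(size(tiled_mma) == Int<128>{});   // 4 warps = 128 threads
  return tiled_mma;
}
```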
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////// fp32 = bf16 * bf16 + fp32 ////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x8_F32BF16BF16F32_TN>
|
||||
: MMA_Traits<SM80_16x8x8_F16F16F16F16_TN>
|
||||
{
|
||||
using ElementDVal = float;
|
||||
using ElementAVal = bfloat16_t;
|
||||
using ElementBVal = bfloat16_t;
|
||||
using ElementCVal = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
: MMA_Traits<SM80_16x8x16_F16F16F16F16_TN>
|
||||
{
|
||||
using ElementDVal = float;
|
||||
using ElementAVal = bfloat16_t;
|
||||
using ElementBVal = bfloat16_t;
|
||||
using ElementCVal = float;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////// fp32 = tf32 * tf32 + fp32 ////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x4_F32TF32TF32F32_TN>
|
||||
{
|
||||
using ElementDVal = float;
|
||||
using ElementAVal = cutlass::tfloat32_t;
|
||||
using ElementBVal = cutlass::tfloat32_t;
|
||||
using ElementCVal = float;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_4>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,_2>,
|
||||
Stride<Stride<_16,_1>,_8>>;
|
||||
using BLayout = SM80_8x4;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x8_F32TF32TF32F32_TN>
|
||||
{
|
||||
using ElementDVal = float;
|
||||
using ElementAVal = cutlass::tfloat32_t;
|
||||
using ElementBVal = cutlass::tfloat32_t;
|
||||
using ElementCVal = float;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_8>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _2>>,
|
||||
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
|
||||
using BLayout = Layout<Shape <Shape <_4,_8>, _2>,
|
||||
Stride<Stride<_8,_1>,_32>>;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////// fp64 = fp64 * fp64 + fp64 ////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
|
||||
{
|
||||
using ElementDVal = double;
|
||||
using ElementAVal = double;
|
||||
using ElementBVal = double;
|
||||
using ElementCVal = double;
|
||||
|
||||
using Shape_MNK = Shape<_8,_8,_4>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = SM80_8x4;
|
||||
using BLayout = SM80_8x4;
|
||||
using CLayout = SM80_8x8_Row;
|
||||
};
|
||||
|
||||
// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x4_C64C64C64C64_TN>
|
||||
: MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
|
||||
{
|
||||
using ElementDVal = complex<double>;
|
||||
using ElementAVal = complex<double>;
|
||||
using ElementBVal = complex<double>;
|
||||
using ElementCVal = complex<double>;
|
||||
};
|
||||
|
||||
// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x4_GC64C64C64GC64_TN>
|
||||
: MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
|
||||
{
|
||||
using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
|
||||
using ElementAVal = complex<double>;
|
||||
using ElementBVal = complex<double>;
|
||||
using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////// s32 = s8 * s8 + s32 ///////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = int8_t;
|
||||
using ElementBVal = int8_t;
|
||||
using ElementCVal = int32_t;
|
||||
|
||||
using Shape_MNK = Shape<_8,_8,_16>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = SM80_8x16_Row;
|
||||
using BLayout = SM80_8x16_Row;
|
||||
using CLayout = SM80_8x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32S8S8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = int8_t;
|
||||
using ElementBVal = int8_t;
|
||||
using ElementCVal = int32_t;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_16>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,Shape < _4,_2>>,
|
||||
Stride<Stride<_64,_1>,Stride<_16,_8>>>;
|
||||
using BLayout = SM80_8x16_Row;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32S8S8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = int8_t;
|
||||
using ElementBVal = int8_t;
|
||||
using ElementCVal = int32_t;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_32>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,Shape < _4,_2, _2>>,
|
||||
Stride<Stride<_64,_1>,Stride<_16,_8,_256>>>;
|
||||
using BLayout = Layout<Shape <Shape < _4,_8>, Shape <_4, _2>>,
|
||||
Stride<Stride<_32,_1>, Stride<_8,_128>>>;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32S8S8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN> {};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////// s32 = s8 * u8 + s32 ///////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32S8U8S32_TN>
|
||||
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = int8_t;
|
||||
using ElementBVal = uint8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32S8U8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_8x8x16_S32S8U8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32S8U8S32_TN>
|
||||
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = int8_t;
|
||||
using ElementBVal = uint8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32S8U8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x16_S32S8U8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32S8U8S32_TN>
|
||||
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = int8_t;
|
||||
using ElementBVal = uint8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32S8U8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x32_S32S8U8S32_TN> {};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////// s32 = u8 * s8 + s32 ///////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32U8S8S32_TN>
|
||||
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = uint8_t;
|
||||
using ElementBVal = int8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32U8S8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_8x8x16_S32U8S8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32U8S8S32_TN>
|
||||
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = uint8_t;
|
||||
using ElementBVal = int8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32U8S8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x16_S32U8S8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32U8S8S32_TN>
|
||||
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = uint8_t;
|
||||
using ElementBVal = int8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32U8S8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x32_S32U8S8S32_TN> {};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////// s32 = u8 * u8 + s32 ///////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32U8U8S32_TN>
|
||||
: MMA_Traits<SM80_8x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = uint8_t;
|
||||
using ElementBVal = uint8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_8x8x16_S32U8U8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_8x8x16_S32U8U8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32U8U8S32_TN>
|
||||
: MMA_Traits<SM80_16x8x16_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = uint8_t;
|
||||
using ElementBVal = uint8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x16_S32U8U8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x16_S32U8U8S32_TN> {};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN>
|
||||
: MMA_Traits<SM80_16x8x32_S32S8S8S32_TN>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = uint8_t;
|
||||
using ElementBVal = uint8_t;
|
||||
using ElementCVal = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN_SATURATE>
|
||||
: MMA_Traits<SM80_16x8x32_S32U8U8S32_TN> {};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////////// s32 = b1 ^ b1 + s32 ///////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
|
||||
{
|
||||
using ElementDVal = int32_t;
|
||||
using ElementAVal = cute::uint1b_t;
|
||||
using ElementBVal = cute::uint1b_t;
|
||||
using ElementCVal = int32_t;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_256>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <_32,Shape < _8, _4,_2, _2>>,
|
||||
Stride<_64,Stride<_64,_16,_8,_2048>>>;
|
||||
using BLayout = Layout<Shape <_32,Shape <_32, _2>>,
|
||||
Stride<_32,Stride< _1,_1024>>>;
|
||||
using CLayout = SM80_16x8_Row;
|
||||
};
|
||||
} // end namespace cute

132  include/cute/atom/mma_traits_sm90.hpp  Normal file
@@ -0,0 +1,132 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/mma_sm90.hpp>
|
||||
#include <cute/atom/mma_traits.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////// fp64 = fp64 * fp64 + fp64 ////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
|
||||
{
|
||||
using ElementDVal = double;
|
||||
using ElementAVal = double;
|
||||
using ElementBVal = double;
|
||||
using ElementCVal = double;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_4>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,_2>,
|
||||
Stride<Stride<_16,_1>,_8>>;
|
||||
using BLayout = Layout<Shape <Shape < _4,_8>,_1>,
|
||||
Stride<Stride< _8,_1>,_0>>;
|
||||
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
|
||||
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
|
||||
{
|
||||
using ElementDVal = double;
|
||||
using ElementAVal = double;
|
||||
using ElementBVal = double;
|
||||
using ElementCVal = double;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_8>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _2>>,
|
||||
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
|
||||
using BLayout = Layout<Shape <Shape < _4,_8>, _2>,
|
||||
Stride<Stride< _8,_1>,_32>>;
|
||||
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
|
||||
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
|
||||
{
|
||||
using ElementDVal = double;
|
||||
using ElementAVal = double;
|
||||
using ElementBVal = double;
|
||||
using ElementCVal = double;
|
||||
|
||||
using Shape_MNK = Shape<_16,_8,_16>;
|
||||
using ThrID = Layout<_32>;
|
||||
using ALayout = Layout<Shape <Shape < _4,_8>,Shape <_2, _4>>,
|
||||
Stride<Stride<_16,_1>,Stride<_8,_64>>>;
|
||||
using BLayout = Layout<Shape <Shape < _4,_8>, _4>,
|
||||
Stride<Stride< _8,_1>,_32>>;
|
||||
using CLayout = Layout<Shape <Shape < _4,_8>,Shape < _2,_2>>,
|
||||
Stride<Stride<_32,_1>,Stride<_16,_8>>>;
|
||||
};

///////////////////////////////////////////////////////////////////////////////////
//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 ////////////////////////////
///////////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM90_16x8x4_C64C64C64C64_TN>
     : MMA_Traits<SM90_16x8x4_F64F64F64F64_TN>
{
  using ElementDVal = complex<double>;
  using ElementAVal = complex<double>;
  using ElementBVal = complex<double>;
  using ElementCVal = complex<double>;
};

template <>
struct MMA_Traits<SM90_16x8x8_C64C64C64C64_TN>
     : MMA_Traits<SM90_16x8x8_F64F64F64F64_TN>
{
  using ElementDVal = complex<double>;
  using ElementAVal = complex<double>;
  using ElementBVal = complex<double>;
  using ElementCVal = complex<double>;
};

template <>
struct MMA_Traits<SM90_16x8x16_C64C64C64C64_TN>
     : MMA_Traits<SM90_16x8x16_F64F64F64F64_TN>
{
  using ElementDVal = complex<double>;
  using ElementAVal = complex<double>;
  using ElementBVal = complex<double>;
  using ElementCVal = complex<double>;
};

} // end namespace cute
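Note that the `complex<double>` specializations above deliberately inherit the fp64 traits, so only the element types change while the thread/value layouts stay shared. A tiny illustrative check (not part of the diff; it assumes `complex` here names CuTe's complex type):

```cpp
#include <type_traits>

using F64Traits = cute::MMA_Traits<cute::SM90_16x8x8_F64F64F64F64_TN>;
using C64Traits = cute::MMA_Traits<cute::SM90_16x8x8_C64C64C64C64_TN>;

// Same per-thread layouts, different element types.
static_assert(std::is_same_v<F64Traits::ALayout, C64Traits::ALayout>);
static_assert(std::is_same_v<C64Traits::ElementAVal, cute::complex<double>>);
```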

2975  include/cute/atom/mma_traits_sm90_gmma.hpp  Normal file  (file diff suppressed because it is too large)

121  include/cute/config.hpp  Normal file
@@ -0,0 +1,121 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
#pragma once

#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
#  define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
#  define CUTE_DEVICE      __forceinline__          __device__
#  define CUTE_HOST        __forceinline__ __host__
#else
#  define CUTE_HOST_DEVICE inline
#  define CUTE_DEVICE      inline
#  define CUTE_HOST        inline
#endif // CUTE_HOST_DEVICE, CUTE_DEVICE
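These annotation macros are what the rest of CuTe uses to write single-source host/device code. A trivial illustrative use (not from the diff):

```cpp
// Compiled as __forceinline__ __host__ __device__ under nvcc/NVHPC,
// and as a plain inline function in host-only builds.
CUTE_HOST_DEVICE constexpr
int round_up(int x, int multiple)
{
  return ((x + multiple - 1) / multiple) * multiple;
}
```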

#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
#  define CUTE_UNROLL    #pragma unroll
#  define CUTE_NO_UNROLL #pragma unroll 1
#else
#  define CUTE_UNROLL
#  define CUTE_NO_UNROLL
#endif // CUTE_UNROLL

#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
#  define CUTE_INLINE_CONSTANT static const __device__
#else
#  define CUTE_INLINE_CONSTANT static constexpr
#endif

// Some versions of GCC < 11 have trouble deducing that a
// function with "auto" return type and all of its returns in an "if
// constexpr ... else" statement must actually return. Thus, GCC
// emits spurious "missing return statement" build warnings.
// Developers can suppress these warnings by using the
// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon.
// It's harmless to use the macro for other GCC versions or other
// compilers, but it has no effect.
#if ! defined(CUTE_GCC_UNREACHABLE)
#  if defined(__GNUC__) && __GNUC__ < 11
// GCC 10, but not 7.5, 9.4.0, or 11, issues "missing return
// statement" warnings without this little bit of help.
#    define CUTE_GCC_UNREACHABLE __builtin_unreachable()
#  else
#    define CUTE_GCC_UNREACHABLE
#  endif
#endif
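The intended pattern looks like the following sketch: every `if constexpr` branch returns, and the trailing macro (plus semicolon) silences the spurious GCC warning without affecting other compilers.

```cpp
template <bool kRowMajor>
CUTE_HOST_DEVICE constexpr
auto linear_index(int row, int col, int ld)
{
  if constexpr (kRowMajor) {
    return row * ld + col;
  } else {
    return col * ld + row;
  }

  CUTE_GCC_UNREACHABLE;
}
```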

//
// Assertion helpers
//

#include <cassert>

#define CUTE_STATIC_ASSERT          static_assert
#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__)

#if defined(__CUDA_ARCH__)
#  define CUTE_RUNTIME_ASSERT(x) asm volatile ("brkpt;\n" ::: "memory")
#else
#  define CUTE_RUNTIME_ASSERT(x) assert(0 && x)
#endif
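`CUTE_STATIC_ASSERT_V` differs from a plain `static_assert` in that it asserts on a value-carrying *type* (anything exposing `::value`), which is what CuTe's compile-time integers and static layout queries return. An illustrative use (not from the diff; assumes `cute::Int` from `cute/numeric/int.hpp`):

```cpp
// Int<8>{} == Int<8>{} yields a true-valued integral-constant type,
// so the macro can pull out ::value at compile time.
CUTE_STATIC_ASSERT_V(cute::Int<8>{} == cute::Int<8>{});
CUTE_STATIC_ASSERT(2 + 2 == 4, "plain static_assert takes an ordinary bool");
```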

//
// IO
//

#include <cstdio>
#include <iostream>
#include <iomanip>

//
// Support
//

#include <cute/util/type_traits.hpp>

//
// Basic types
//

#include <cute/numeric/int.hpp>
#include <cute/numeric/real.hpp>
#include <cute/numeric/half.hpp>
#include <cute/numeric/float8.hpp>
#include <cute/numeric/bfloat.hpp>
#include <cute/numeric/tfloat.hpp>
#include <cute/numeric/complex.hpp>

//
// Debugging utilities
//

#include <cute/util/print.hpp>
#include <cute/util/debug.hpp>

70  include/cute/container/alignment.hpp  Normal file
@@ -0,0 +1,70 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/numeric/int.hpp>
#include <cute/numeric/math.hpp>

namespace cute
{

// Test if a pointer is aligned to N bytes
template <int N>
CUTE_HOST_DEVICE constexpr
bool
is_byte_aligned(void const* const ptr)
{
  static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check");
  return (reinterpret_cast<uintptr_t>(ptr) & (N-1)) == 0;
}

#if defined(__CUDACC__)
#  define CUTE_ALIGNAS(n) __align__(n)
#else
#  define CUTE_ALIGNAS(n) alignas(n)
#endif

template <std::size_t Alignment>
struct aligned_struct {};

template <> struct CUTE_ALIGNAS(  1) aligned_struct<  1> {};
template <> struct CUTE_ALIGNAS(  2) aligned_struct<  2> {};
template <> struct CUTE_ALIGNAS(  4) aligned_struct<  4> {};
template <> struct CUTE_ALIGNAS(  8) aligned_struct<  8> {};
template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {};
template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {};
template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {};
template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {};
template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {};

} // end namespace cute
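Typical use of `is_byte_aligned` is guarding vectorized (for example 128-bit) memory accesses. A small sketch with a hypothetical helper (not part of the diff):

```cpp
// Hypothetical helper: true if `ptr` may be accessed with 16-byte (128-bit) vectors.
CUTE_HOST_DEVICE
bool can_vectorize_128b(void const* ptr)
{
  return cute::is_byte_aligned<16>(ptr);
}
```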

282  include/cute/container/array.hpp  Normal file
@@ -0,0 +1,282 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class T, std::size_t N>
|
||||
struct array
|
||||
{
|
||||
using value_type = T;
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using reference = value_type&;
|
||||
using const_reference = const value_type&;
|
||||
using pointer = value_type*;
|
||||
using const_pointer = const value_type*;
|
||||
using iterator = pointer;
|
||||
using const_iterator = const_pointer;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator[](size_type pos)
|
||||
{
|
||||
return begin()[pos];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference operator[](size_type pos) const
|
||||
{
|
||||
return begin()[pos];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference front()
|
||||
{
|
||||
return *begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference front() const
|
||||
{
|
||||
return *begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference back()
|
||||
{
|
||||
// return *rbegin();
|
||||
return operator[](N-1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference back() const
|
||||
{
|
||||
// return *rbegin();
|
||||
return operator[](N-1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T* data()
|
||||
{
|
||||
return __elems_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const* data() const
|
||||
{
|
||||
return __elems_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator begin()
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator begin() const
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cbegin()
|
||||
{
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cbegin() const
|
||||
{
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator end()
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator end() const
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cend()
|
||||
{
|
||||
return end();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cend() const
|
||||
{
|
||||
return end();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool empty() const
|
||||
{
|
||||
return size() == 0;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type size() const
|
||||
{
|
||||
return N;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type max_size() const
|
||||
{
|
||||
return size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void fill(const T& value)
|
||||
{
|
||||
for (auto& e : *this) {
|
||||
e = value;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void clear()
|
||||
{
|
||||
fill(T(0));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void swap(array& other)
|
||||
{
|
||||
using std::swap;
|
||||
for (size_type i = 0; i < size(); ++i) {
|
||||
swap((*this)[i], other[i]);
|
||||
}
|
||||
}
|
||||
|
||||
value_type __elems_[N > 0 ? N : 1];
|
||||
};
|
||||
|
||||
|
||||
template<class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(array<T,N> const& lhs, array<T,N> const& rhs)
|
||||
{
|
||||
for (std::size_t i = 0; i < N; ++i) {
|
||||
if (lhs[i] != rhs[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void clear(array<T,N>& a)
|
||||
{
|
||||
a.fill(T(0));
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void fill(array<T,N>& a, T const& value)
|
||||
{
|
||||
a.fill(value);
|
||||
}
|
||||
|
||||
template<class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void swap(array<T,N>& a, array<T,N>& b)
|
||||
{
|
||||
a.swap(b);
|
||||
}
|
||||
|
||||
} // end cute
|
||||
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::array
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& get(array<T,N>& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const& get(array<T,N> const& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T&& get(array<T,N>&& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return std::move(a[I]);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class T, std::size_t N>
|
||||
struct tuple_size<cute::array<T,N>>
|
||||
: std::integral_constant<std::size_t, N>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
struct tuple_element<I, cute::array<T,N>>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end std
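Because `tuple_size`, `tuple_element`, and `get` are provided, `cute::array` participates in the tuple protocol much like `std::array`. A small illustrative check (not part of the diff):

```cpp
#include <cute/container/array.hpp>

CUTE_HOST_DEVICE constexpr
int first_plus_last(cute::array<int,3> const& a)
{
  return cute::get<0>(a) + cute::get<2>(a);   // tuple-style access defined above
}

static_assert(first_plus_last({1, 2, 3}) == 4);
static_assert(std::tuple_size<cute::array<int,3>>::value == 3);
```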

276  include/cute/container/array_aligned.hpp  Normal file
@@ -0,0 +1,276 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/alignment.hpp>
|
||||
#include <cute/numeric/int.hpp>
|
||||
#include <cute/numeric/math.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <typename T, std::size_t N, std::size_t Alignment = 16>
|
||||
struct array_aligned
|
||||
: public aligned_struct<Alignment>
|
||||
{
|
||||
/// Make sure the Alignment makes sense wrt the size of elements.
|
||||
static_assert(Alignment == 16 || Alignment >= sizeof(T), "Alignment is too small");
|
||||
/// Alignment must be a power of two
|
||||
static_assert(has_single_bit(Alignment), "Alignment must be a power of two");
|
||||
|
||||
using value_type = T;
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using reference = value_type&;
|
||||
using const_reference = const value_type&;
|
||||
using pointer = value_type*;
|
||||
using const_pointer = const value_type*;
|
||||
using iterator = pointer;
|
||||
using const_iterator = const_pointer;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator[](size_type pos)
|
||||
{
|
||||
return begin()[pos];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference operator[](size_type pos) const
|
||||
{
|
||||
return begin()[pos];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference front()
|
||||
{
|
||||
return *begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference front() const
|
||||
{
|
||||
return *begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference back()
|
||||
{
|
||||
// return *rbegin();
|
||||
return operator[](N-1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference back() const
|
||||
{
|
||||
// return *rbegin();
|
||||
return operator[](N-1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T* data()
|
||||
{
|
||||
return reinterpret_cast<T*>(storage);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const* data() const
|
||||
{
|
||||
return reinterpret_cast<T const*>(storage);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator begin()
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator begin() const
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cbegin()
|
||||
{
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cbegin() const
|
||||
{
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator end()
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator end() const
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cend()
|
||||
{
|
||||
return end();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cend() const
|
||||
{
|
||||
return end();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool empty() const
|
||||
{
|
||||
return size() == 0;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type size() const
|
||||
{
|
||||
return N;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type max_size() const
|
||||
{
|
||||
return size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void fill(T const& value)
|
||||
{
|
||||
for (auto& e : *this) {
|
||||
e = value;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void clear()
|
||||
{
|
||||
fill(T(0));
|
||||
}
|
||||
|
||||
// Not private, we want trivial type
|
||||
//private:
|
||||
|
||||
/// Storage type to use for Elements
|
||||
using StorageType = typename uint_byte<static_cast<int>(Alignment)>::type;
|
||||
|
||||
/// Ensure that there's enough storage for all elements
|
||||
static_assert(sizeof(StorageType) <= Alignment, "StorageType is too big for given alignment");
|
||||
|
||||
/// Number of elements in the storage
|
||||
static constexpr std::size_t storageN = (sizeof(T)*N + sizeof(StorageType) - 1) / sizeof(StorageType);
|
||||
|
||||
/// The storage.
|
||||
StorageType storage[storageN > 0 ? storageN : 1];
|
||||
};
|
||||
|
||||
//
|
||||
// Operators
|
||||
//
|
||||
|
||||
template <typename T, std::size_t N, std::size_t Alignment>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void clear(array_aligned<T, N, Alignment>& a)
|
||||
{
|
||||
a.clear();
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, std::size_t Alignment>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void fill(array_aligned<T, N, Alignment>& a, T const& value)
|
||||
{
|
||||
a.fill(value);
|
||||
}
|
||||
|
||||
} // end namespace cute
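`array_aligned` is intended as an alignment-guaranteed backing store, for example for shared-memory tiles. A minimal sketch (names and sizes are illustrative only):

```cpp
#include <cute/container/array_aligned.hpp>

// Hypothetical shared-memory staging buffer for a 64x64 float tile.
// The default template argument requests 16-byte alignment.
struct SharedStorage
{
  cute::array_aligned<float, 64 * 64> tile_A;
};

static_assert(alignof(SharedStorage) >= 16, "alignment is inherited from aligned_struct<16>");
```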
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::array
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& get(array_aligned<T,N>& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const& get(array_aligned<T,N> const& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T&& get(array_aligned<T,N>&& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return std::move(a[I]);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class T, std::size_t N>
|
||||
struct tuple_size<cute::array_aligned<T,N>>
|
||||
: std::integral_constant<std::size_t, N>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
struct tuple_element<I, cute::array_aligned<T,N>>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end std

613  include/cute/container/array_subbyte.hpp  Normal file
@@ -0,0 +1,613 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause terms omitted; the full text is identical in every new header of this diff.)
 **************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Statically sized array of elements that accommodates subbyte trivial types
|
||||
in a packed storage.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/numeric/int.hpp> // sizeof_bits
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically sized array for any data type
|
||||
template <class T, std::size_t N>
|
||||
class array_subbyte
|
||||
{
|
||||
public:
|
||||
|
||||
/// Number of total bits in the array
|
||||
static constexpr int kSizeBits = sizeof_bits<T>::value * N;
|
||||
|
||||
/// Storage type
|
||||
using Storage = typename std::conditional<
|
||||
(kSizeBits % 32) == 0,
|
||||
uint32_t,
|
||||
typename std::conditional<
|
||||
(kSizeBits % 16) == 0,
|
||||
uint16_t,
|
||||
uint8_t
|
||||
>::type
|
||||
>::type;
|
||||
|
||||
|
||||
/// Number of logical elements per stored object
|
||||
static constexpr int kElementsPerStoredItem = sizeof_bits<Storage>::value / sizeof_bits<T>::value;
|
||||
|
||||
/// Number of storage elements
|
||||
static constexpr std::size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem;
|
||||
|
||||
/// Bitmask for covering one item
|
||||
static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits<T>::value) - 1);
|
||||
|
||||
//
|
||||
// C++ standard members with reference and iterator types omitted
|
||||
//
|
||||
|
||||
using value_type = T;
|
||||
using pointer = value_type*;
|
||||
using const_pointer = value_type const*;
|
||||
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
|
||||
//
|
||||
// References
|
||||
//
|
||||
|
||||
/// Reference object inserts or extracts sub-byte items
|
||||
class reference {
|
||||
/// Pointer to storage element
|
||||
Storage* ptr_;
|
||||
|
||||
/// Index into elements packed into Storage object
|
||||
int idx_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference() : ptr_(nullptr), idx_(0) {}
|
||||
|
||||
/// Ctor
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference(Storage* ptr, int idx = 0) : ptr_(ptr), idx_(idx) {}
|
||||
|
||||
/// Assignment
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference& operator=(T x) {
|
||||
Storage item = (reinterpret_cast<Storage const&>(x) & bit_mask_);
|
||||
Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits<T>::value)));
|
||||
*ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits<T>::value)));
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T get() const {
|
||||
Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & bit_mask_);
|
||||
return reinterpret_cast<T const&>(item);
|
||||
}
|
||||
|
||||
/// Extract to type T -- disable if T == bool
|
||||
template <class U = T, __CUTE_REQUIRES(not std::is_same<U,bool>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator T() const {
|
||||
return get();
|
||||
}
|
||||
|
||||
// Extract to bool -- potentially faster impl
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator bool() const {
|
||||
return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits<T>::value)));
|
||||
}
|
||||
|
||||
/// Explicit cast to int
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
explicit operator int() const {
|
||||
return int(get());
|
||||
}
|
||||
|
||||
/// Explicit cast to float
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
explicit operator float() const {
|
||||
return float(get());
|
||||
}
|
||||
};
|
||||
|
||||
/// Reference object extracts sub-byte items
|
||||
class const_reference {
|
||||
|
||||
/// Pointer to storage element
|
||||
Storage const* ptr_;
|
||||
|
||||
/// Index into elements packed into Storage object
|
||||
int idx_;
|
||||
|
||||
public:
|
||||
|
||||
/// Default ctor
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference(): ptr_(nullptr), idx_(0) { }
|
||||
|
||||
/// Ctor
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const T get() const {
|
||||
Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & bit_mask_);
|
||||
return reinterpret_cast<T const&>(item);
|
||||
}
|
||||
|
||||
/// Extract to type T -- disable if T == bool
|
||||
template <class U = T, __CUTE_REQUIRES(not std::is_same<U,bool>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator T() const {
|
||||
return get();
|
||||
}
|
||||
|
||||
// Extract to bool -- potentially faster impl
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator bool() const {
|
||||
return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits<T>::value)));
|
||||
}
|
||||
|
||||
/// Explicit cast to int
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
explicit operator int() const {
|
||||
return int(get());
|
||||
}
|
||||
|
||||
/// Explicit cast to float
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
explicit operator float() const {
|
||||
return float(get());
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Iterators
|
||||
//
|
||||
|
||||
/// Bidirectional iterator over elements
|
||||
class iterator {
|
||||
|
||||
/// Pointer to storage element
|
||||
Storage* ptr_;
|
||||
|
||||
/// Index into elements packed into Storage object
|
||||
int idx_;
|
||||
|
||||
public:
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator(): ptr_(nullptr), idx_(0) { }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator(Storage* ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator& operator++() {
|
||||
++idx_;
|
||||
if (idx_ == kElementsPerStoredItem) {
|
||||
++ptr_;
|
||||
idx_ = 0;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator& operator--() {
|
||||
if (idx_) {
|
||||
--idx_;
|
||||
} else {
|
||||
--ptr_;
|
||||
idx_ = kElementsPerStoredItem - 1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator operator++(int) {
|
||||
iterator ret(*this);
|
||||
++(*this);
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator operator--(int) {
|
||||
iterator ret(*this);
|
||||
--(*this);
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator& operator+=(int k) {
|
||||
idx_ += k;
|
||||
ptr_ += idx_ / kElementsPerStoredItem;
|
||||
idx_ = idx_ % kElementsPerStoredItem;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator operator+(int k) const {
|
||||
return iterator(ptr_,idx_) += k;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator*() const {
|
||||
return reference(ptr_, idx_);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator[](int k) const {
|
||||
return *(*this + k);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(iterator const& other) const {
|
||||
return ptr_ == other.ptr_ && idx_ == other.idx_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator!=(iterator const& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
/// Bidirectional constant iterator over elements
|
||||
class const_iterator {
|
||||
|
||||
/// Pointer to storage element
|
||||
Storage const* ptr_;
|
||||
|
||||
/// Index into elements packed into Storage object
|
||||
int idx_;
|
||||
|
||||
public:
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator(): ptr_(nullptr), idx_(0) { }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator& operator++() {
|
||||
++idx_;
|
||||
if (idx_ == kElementsPerStoredItem) {
|
||||
++ptr_;
|
||||
idx_ = 0;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator& operator--() {
|
||||
if (idx_) {
|
||||
--idx_;
|
||||
} else {
|
||||
--ptr_;
|
||||
idx_ = kElementsPerStoredItem - 1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator operator++(int) {
|
||||
const_iterator ret(*this);
|
||||
++idx_;
|
||||
if (idx_ == kElementsPerStoredItem) {
|
||||
++ptr_;
|
||||
idx_ = 0;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator operator--(int) {
|
||||
const_iterator ret(*this);
|
||||
if (idx_) {
|
||||
--idx_;
|
||||
} else {
|
||||
--ptr_;
|
||||
idx_ = kElementsPerStoredItem - 1;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator& operator+=(int k) {
|
||||
idx_ += k;
|
||||
ptr_ += idx_ / kElementsPerStoredItem;
|
||||
idx_ = idx_ % kElementsPerStoredItem;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator operator+(int k) const {
|
||||
return const_iterator(ptr_,idx_) += k;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference operator*() const {
|
||||
return const_reference(ptr_, idx_);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference operator[](int k) const {
|
||||
return *(*this + k);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(const_iterator const& other) const {
|
||||
return ptr_ == other.ptr_ && idx_ == other.idx_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator!=(const_iterator const& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Internal storage
|
||||
Storage storage[kStorageElements];
|
||||
|
||||
public:
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
array_subbyte() { }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
array_subbyte(array_subbyte const& x) {
|
||||
CUTE_UNROLL
|
||||
for (unsigned i = 0; i < kStorageElements; ++i) {
|
||||
storage[i] = x.storage[i];
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type size() const {
|
||||
return N;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type max_size() const {
|
||||
return N;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool empty() const {
|
||||
return !N;
|
||||
}
|
||||
|
||||
/// Efficient clear method
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void clear() {
|
||||
CUTE_UNROLL
|
||||
for (unsigned i = 0; i < kStorageElements; ++i) {
|
||||
storage[i] = Storage(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Efficient fill method
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void fill(T const& value) {
|
||||
Storage item = (reinterpret_cast<Storage const&>(value) & bit_mask_);
|
||||
|
||||
// Reproduce the value over the bits of the storage item
|
||||
CUTE_UNROLL
|
||||
for (unsigned s = sizeof_bits<T>::value; s < sizeof_bits<Storage>::value; s *= 2) {
|
||||
item |= item << s;
|
||||
}
|
||||
|
||||
CUTE_UNROLL
|
||||
for (unsigned i = 0; i < kStorageElements; ++i) {
|
||||
storage[i] = item;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference at(size_type pos) {
|
||||
return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference at(size_type pos) const {
|
||||
return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator[](size_type pos) {
|
||||
return at(pos);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference operator[](size_type pos) const {
|
||||
return at(pos);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference front() {
|
||||
return at(0);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference front() const {
|
||||
return at(0);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference back() {
|
||||
return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_reference back() const {
|
||||
return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
pointer data() {
|
||||
return reinterpret_cast<pointer>(storage);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_pointer data() const {
|
||||
return reinterpret_cast<const_pointer>(storage);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Storage* raw_data() {
|
||||
return storage;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Storage const* raw_data() const {
|
||||
return storage;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator begin() {
|
||||
return iterator(storage);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator begin() const {
|
||||
return const_iterator(storage);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cbegin() const {
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator end() {
|
||||
return iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator end() const {
|
||||
return const_iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const_iterator cend() const {
|
||||
return end();
|
||||
}
|
||||
|
||||
//
|
||||
// Comparison operators
|
||||
//
|
||||
|
||||
};
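// Hedged usage sketch (not part of the original source). It shows how the packed
// reference/iterator machinery above is meant to be used; "uint4_t" is an assumed
// 4-bit element type and a 32-bit Storage word is assumed for the comments.
#if 0
array_subbyte<uint4_t, 8> a;   // 8 nibbles packed into a single (assumed 32-bit) Storage word
a.fill(uint4_t(5));            // replicates 0x5 across every nibble of the word
a[3] = uint4_t(9);             // reference shifts/masks nibble 3 in place
uint4_t v = a[3];              // reads back (storage >> 12) & 0xF
#endif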
|
||||
|
||||
//
|
||||
// Operators
|
||||
//
|
||||
|
||||
template <class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void clear(array_subbyte<T,N>& a)
|
||||
{
|
||||
a.clear();
|
||||
}
|
||||
|
||||
template <class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void fill(array_subbyte<T,N>& a, T const& value)
|
||||
{
|
||||
a.fill(value);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::array_subbyte
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& get(array_subbyte<T,N>& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const& get(array_subbyte<T,N> const& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T&& get(array_subbyte<T,N>&& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return std::move(a[I]);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class T, std::size_t N>
|
||||
struct tuple_size<cute::array_subbyte<T,N>>
|
||||
: std::integral_constant<std::size_t, N>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class T, std::size_t N>
|
||||
struct tuple_element<I, cute::array_subbyte<T,N>>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end namespace std
|
||||
include/cute/container/array_view.hpp (new file, 274 lines)
@@ -0,0 +1,274 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class T, std::size_t N>
|
||||
struct array_view
|
||||
{
|
||||
using value_type = T;
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using reference = value_type&;
|
||||
using const_reference = const value_type&;
|
||||
using pointer = value_type*;
|
||||
using const_pointer = const value_type*;
|
||||
using iterator = pointer;
|
||||
using const_iterator = const_pointer;
|
||||
|
||||
array_view(array<T,N>& a)
|
||||
: __elems_(a.data()) {}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
reference operator[](size_type pos)
|
||||
{
|
||||
return begin()[pos];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_reference operator[](size_type pos) const
|
||||
{
|
||||
return begin()[pos];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
reference front()
|
||||
{
|
||||
return *begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_reference front() const
|
||||
{
|
||||
return *begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
reference back()
|
||||
{
|
||||
// return *rbegin();
|
||||
return operator[](N-1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_reference back() const
|
||||
{
|
||||
// return *rbegin();
|
||||
return operator[](N-1);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
T* data()
|
||||
{
|
||||
return __elems_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const T* data() const
|
||||
{
|
||||
return __elems_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
iterator begin()
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_iterator begin() const
|
||||
{
|
||||
return data();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_iterator cbegin()
|
||||
{
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_iterator cbegin() const
|
||||
{
|
||||
return begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
iterator end()
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_iterator end() const
|
||||
{
|
||||
return data() + size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_iterator cend()
|
||||
{
|
||||
return end();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
const_iterator cend() const
|
||||
{
|
||||
return end();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool empty() const
|
||||
{
|
||||
return size() == 0;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type size() const
|
||||
{
|
||||
return N;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
size_type max_size() const
|
||||
{
|
||||
return size();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void fill(const T& value)
|
||||
{
|
||||
for(auto& e : *this)
|
||||
{
|
||||
e = value;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void swap(array_view& other)
|
||||
{
|
||||
using std::swap;
|
||||
swap(__elems_, other.__elems_);
|
||||
}
|
||||
|
||||
value_type* __elems_;
|
||||
};
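// Hedged usage sketch (not part of the original source): array_view is a non-owning
// view over a cute::array (assumed to be available from cute/container/array.hpp);
// it stores only the data pointer and takes its extent from the compile-time N.
#if 0
cute::array<int, 4> backing = {1, 2, 3, 4};
cute::array_view<int, 4> v(backing);
// v.size() == 4, v[2] == 3, and writes through v modify `backing`.
#endif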
|
||||
|
||||
|
||||
template<class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE
|
||||
bool operator==(const array_view<T,N>& lhs, const array_view<T,N>& rhs)
|
||||
{
|
||||
for(std::size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(lhs[i] != rhs[i]) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N>
|
||||
CUTE_HOST_DEVICE
|
||||
void clear(array_view<T, N>& a)
|
||||
{
|
||||
a.fill(T(0));
|
||||
}
|
||||
|
||||
template<class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE
|
||||
void swap(array_view<T,N>& a, array_view<T,N>& b)
|
||||
{
|
||||
a.swap(b);
|
||||
}
|
||||
|
||||
} // end cute
|
||||
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::array_view
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T&
|
||||
get(array_view<T,N>& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
const T&
|
||||
get(const array_view<T,N>& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return a[I];
|
||||
}
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T&&
|
||||
get(array_view<T,N>&& a)
|
||||
{
|
||||
static_assert(I < N, "Index out of range");
|
||||
return std::move(a[I]);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template<class T, std::size_t N>
|
||||
struct tuple_size<cute::array_view<T,N>>
|
||||
: std::integral_constant<std::size_t, N>
|
||||
{};
|
||||
|
||||
template<std::size_t I, class T, std::size_t N>
|
||||
struct tuple_element<I, cute::array_view<T,N>>
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
} // end std
|
||||
include/cute/container/bit_field.hpp (new file, 131 lines)
@@ -0,0 +1,131 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Portable bit field that supports byte and word straddling and can
       be used in unions to define parameters bit-wise.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/numeric/int.hpp> // uint_bit_t
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
class dummy_type {};
|
||||
|
||||
template <uint32_t BitStart, uint32_t NumBits, class OtherValueType = dummy_type>
|
||||
struct bit_field
|
||||
{
|
||||
static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported.");
|
||||
|
||||
// value_type: Use the smallest value type that fits NumBits
|
||||
static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 :
|
||||
(NumBits <= 16) ? 16 :
|
||||
(NumBits <= 32) ? 32 : 64;
|
||||
using value_type = cute::uint_bit_t<value_type_bits>;
|
||||
// storage_type: Use the smallest storage_type that avoids boundary crossing
|
||||
static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 :
|
||||
(BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 :
|
||||
(BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64;
|
||||
using storage_type = cute::uint_bit_t<storage_type_bits>;
|
||||
|
||||
static_assert(sizeof(OtherValueType) == sizeof(value_type) || std::is_same<OtherValueType,dummy_type>::value,
|
||||
"sizeof(OtherValueType) must be same as sizeof(value_type).");
|
||||
|
||||
// Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits)
|
||||
static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits;
|
||||
// Index of storage value for BitStart
|
||||
static constexpr uint32_t idx = BitStart / storage_type_bits;
|
||||
// Bit of data_[idx] for BitStart
|
||||
static constexpr uint32_t bit_lo = BitStart % storage_type_bits;
|
||||
// Number of bits in data_[idx] used for NumBits if straddling, else 0
|
||||
static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0;
|
||||
|
||||
// NumBits mask
|
||||
static constexpr value_type mask = (NumBits < 64) ? ((uint64_t(1) << NumBits) - 1) : uint64_t(-1);
|
||||
// NumBits mask for BitStart
|
||||
static constexpr storage_type mask_lo = storage_type(mask) << bit_lo;
|
||||
// NumBits mask for leftover bits in data_[idx+1] if straddling, else 0
|
||||
static constexpr storage_type mask_hi = (idx + 1 < N) ? (storage_type(mask) >> bit_hi) : 0;
|
||||
|
||||
storage_type data_[N];
|
||||
|
||||
// Get value
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
value_type get() const {
|
||||
storage_type result = (data_[idx] & mask_lo) >> bit_lo;
|
||||
if constexpr (bit_hi) {
|
||||
result |= (data_[idx+1] & mask_hi) << bit_hi;
|
||||
}
|
||||
return static_cast<value_type>(result);
|
||||
}
|
||||
|
||||
// Set value
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void set(value_type x) {
|
||||
storage_type item = static_cast<storage_type>(x & mask);
|
||||
data_[idx] = static_cast<storage_type>((data_[idx] & ~mask_lo) | (item << bit_lo));
|
||||
if constexpr (bit_hi) {
|
||||
data_[idx+1] = static_cast<storage_type>((data_[idx+1] & ~mask_hi) | (item >> bit_hi));
|
||||
}
|
||||
}
|
||||
|
||||
// Assign value
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bit_field& operator=(value_type x) {
|
||||
set(x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Cast to value
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator value_type () const {
|
||||
return get();
|
||||
}
|
||||
|
||||
// Assign OtherValueType
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bit_field& operator=(OtherValueType x) {
|
||||
return *this = *reinterpret_cast<value_type*>(&x);
|
||||
}
|
||||
|
||||
// Cast to OtherValueType
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator OtherValueType () const {
|
||||
value_type x = get();
|
||||
return *reinterpret_cast<OtherValueType*>(&x);
|
||||
}
|
||||
};
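// Hedged usage sketch (not part of the original source): a field that straddles a byte
// boundary widens its storage_type so that get()/set() only ever touch whole words.
#if 0
using straddling_field = bit_field<6, 8>;   // bits [6,14) cross the first byte boundary
static_assert(sizeof(typename straddling_field::storage_type) == 2,
              "an 8-bit field starting at bit 6 needs 16-bit storage");
straddling_field f{};
f = 0xAB;          // set() masks the value and shifts it into place
uint8_t x = f;     // get() reassembles the 8 bits from the 16-bit word
#endif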
|
||||
|
||||
} // end namespace cute
|
||||
include/cute/container/tuple.hpp (new file, 671 lines)
@@ -0,0 +1,671 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
#include <cute/numeric/integral_constant.hpp> // cute::true_type, cute::false_type
|
||||
//#include <cute/container/array.hpp> // Advanced optimizations
|
||||
|
||||
#if 0
|
||||
//
|
||||
// Use of agency::tuple is functional, but is over-engineered for our purposes...
|
||||
// This tends to result in slow compilation times and unintentionally propagated cvref types
|
||||
//
|
||||
|
||||
#include <agency/tuple.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
using agency::tuple;
|
||||
|
||||
using agency::make_tuple;
|
||||
using agency::tuple_cat;
|
||||
|
||||
} // end namespace cute
|
||||
#endif
|
||||
|
||||
// cute::tuple is like std::tuple, with two differences.
|
||||
//
|
||||
// 1. It works on both host and device.
|
||||
// 2. Its template arguments must be semiregular types.
|
||||
//
|
||||
// Semiregular types are default constructible and copyable.
|
||||
// They include "value types" like int or float,
|
||||
// but do _not_ include references like int& or float&.
|
||||
// (See std::tie for an example of a tuple of references.)
|
||||
//
|
||||
// This is simplified over the implementation in std:: and agency:: by ignoring much of
|
||||
// the conversion SFINAE, special overloading, and avoiding cvref template types.
|
||||
// Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding
|
||||
// construction calls, and ignoring any need for unique element addresses.
|
||||
//
|
||||
// Over the agency::tuple implementation, this appears to accelerate compilation times by over 3x.
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// EBO stands for "empty base optimization."
|
||||
// We use this technique to ensure that cute::tuple
|
||||
// doesn't need to waste space storing any template arguments
|
||||
// of cute::tuple that have no data (like integral_constant).
|
||||
// Otherwise, cute::tuple would need to spend at least 1 byte
|
||||
// for each of its template arguments.
|
||||
//
|
||||
// EBO always "holds" a single value of type T.
|
||||
// N is like an array index that TupleBase uses
|
||||
// to access the desired tuple element.
|
||||
template <std::size_t N, class T, bool IsEmpty = std::is_empty<T>::value>
|
||||
struct EBO;
|
||||
|
||||
// Specialization for types T that have no data;
|
||||
// the "static tuple leaf." Valid T here include
|
||||
// integral_constant<U, Value>, Int<Value>,
|
||||
// and any other semiregular type
|
||||
// for which std::is_empty_v<T> is true.
|
||||
template <std::size_t N, class T>
|
||||
struct EBO<N, T, true>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO() {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO(T const&) {}
|
||||
};
|
||||
|
||||
template <std::size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T getv(EBO<N, T, true> const&)
|
||||
{ return {}; }
|
||||
|
||||
// Specialization for types T that are not empty;
|
||||
// the "dynamic tuple leaf." Valid T here include int,
|
||||
// any other integral or floating-point type,
|
||||
// or any semiregular type for which std::is_empty_v<T> is false.
|
||||
template <std::size_t N, class T>
|
||||
struct EBO<N, T, false>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO() : t_{} {}
|
||||
|
||||
template <class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO(U const& u) : t_{u} {}
|
||||
|
||||
T t_;
|
||||
};
|
||||
|
||||
template <std::size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T const& getv(EBO<N, T, false> const& x)
|
||||
{ return x.t_; }
|
||||
|
||||
template <std::size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T& getv(EBO<N, T, false>& x)
|
||||
{ return x.t_; }
|
||||
|
||||
template <std::size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T&& getv(EBO<N, T, false>&& x)
|
||||
{ return static_cast<T&&>(x.t_); }
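// Hedged illustration (not part of the original source): the "static leaf" above holds
// no data and is itself an empty type, while the "dynamic leaf" is exactly as large as
// the value it wraps, which is what makes the aggressive EBO pay off.
#if 0
static_assert(std::is_empty<EBO<0, std::integral_constant<int,2>>>::value,
              "a leaf holding an empty type stores nothing");
static_assert(sizeof(EBO<1, int>) == sizeof(int),
              "a leaf holding a nonempty type stores only the value");
#endif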
|
||||
|
||||
template <class IdxSeq, class... T>
|
||||
struct TupleBase;
|
||||
|
||||
// Base class of cute::tuple.
|
||||
// It inherits from EBO<i, t> for each (i, t) in (I..., T...).
|
||||
// The actual storage (for nonempty t) lives in the base classes.
|
||||
// index_sequence is a way to wrap up a sequence of zero or more
|
||||
// compile-time integer values in a single type.
|
||||
// We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice,
|
||||
// as the type alias TupleBase below indicates.
|
||||
template <std::size_t... I, class... T>
|
||||
struct TupleBase<std::index_sequence<I...>, T...>
|
||||
: EBO<I,T>...
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TupleBase() {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr explicit
|
||||
TupleBase(U const&... u)
|
||||
: EBO<I,T>(u)... {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TupleBase(TupleBase<std::index_sequence<I...>, U...> const& u)
|
||||
: EBO<I,T>(getv(static_cast<EBO<I,U> const&>(u)))... {}
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// make_index_sequence<K> returns index_sequence<0, 1, ..., K-1>.
|
||||
template <class... T>
|
||||
using TupleBase = detail::TupleBase<std::make_index_sequence<sizeof...(T)>, T...>;
|
||||
|
||||
// This is the actual cute::tuple class.
|
||||
// The storage (if any) lives in TupleBase's EBO base classes.
|
||||
template <class... T>
|
||||
struct tuple : TupleBase<T...>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple() {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple(U const&... u) : TupleBase<T...>(u...) {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple(tuple<U...> const& u)
|
||||
: TupleBase<T...>(static_cast<TupleBase<U...> const&>(u)) {}
|
||||
};
|
||||
|
||||
//
|
||||
// get for cute::tuple (just like std::get for std::tuple)
|
||||
//
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(tuple<T...> const& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(t);
|
||||
}
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(tuple<T...>& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(t);
|
||||
}
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(tuple<T...>&& t) noexcept
|
||||
{
|
||||
static_assert(I < sizeof...(T), "Index out of range");
|
||||
return detail::getv<I>(static_cast<tuple<T...>&&>(t));
|
||||
}
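// Hedged usage sketch (not part of the original source): Int<8> is assumed to be cute's
// empty compile-time integer type, so only the runtime int below occupies storage, and
// get<I> behaves like std::get.
#if 0
cute::tuple<Int<8>, int> t(Int<8>{}, 42);
static_assert(sizeof(t) == sizeof(int), "the empty Int<8> leaf adds no storage");
// get<0>(t) returns Int<8>{} by value; get<1>(t) returns a reference to the stored 42.
#endif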
|
||||
|
||||
//
|
||||
// Custom is_tuple trait simply checks the existence of std::tuple_size
|
||||
// and assumes std::get<I>(.), std::tuple_element<I,.>
|
||||
//
|
||||
namespace detail {
|
||||
|
||||
template <class T>
|
||||
std::integral_constant<bool, std::tuple_size<T>::value >= 0> has_tuple_size(int);
|
||||
|
||||
template <class T>
|
||||
std::false_type has_tuple_size(...);
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T>
|
||||
struct is_tuple : decltype(detail::has_tuple_size<T>(0)) {};
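// Hedged sketch (not part of the original source): anything with a std::tuple_size
// specialization is treated as a tuple, while plain scalars are not.
#if 0
static_assert( is_tuple<std::tuple<int, float>>::value, "std::tuple exposes tuple_size");
static_assert(!is_tuple<int>::value,                    "a scalar is not a tuple");
#endif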
|
||||
|
||||
//
|
||||
// make_tuple (value-based implementation)
|
||||
//
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple<T...>
|
||||
make_tuple(T const&... t)
|
||||
{
|
||||
return {t...};
|
||||
}
|
||||
|
||||
//
|
||||
// tuple_cat concatenates multiple cute::tuple into a single cute::tuple,
|
||||
// just like std::tuple_cat for std::tuple.
|
||||
//
|
||||
|
||||
#if 0
|
||||
// Original implementation
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T0, class T1,
|
||||
std::size_t... I0, std::size_t... I1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1,
|
||||
std::index_sequence<I0...>, std::index_sequence<I1...>)
|
||||
{
|
||||
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple<>
|
||||
tuple_cat()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class Tuple,
|
||||
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Tuple const&
|
||||
tuple_cat(Tuple const& t)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <class T0, class T1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1)
|
||||
{
|
||||
return detail::tuple_cat(t0, t1,
|
||||
std::make_index_sequence<std::tuple_size<T0>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T1>::value>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts)
|
||||
{
|
||||
return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// Extended implementation
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T0, class T1,
|
||||
std::size_t... I0, std::size_t... I1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1,
|
||||
std::index_sequence<I0...>, std::index_sequence<I1...>)
|
||||
{
|
||||
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)...);
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2,
|
||||
std::size_t... I0, std::size_t... I1, std::size_t... I2>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2,
|
||||
std::index_sequence<I0...>, std::index_sequence<I1...>, std::index_sequence<I2...>)
|
||||
{
|
||||
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)...);
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class T3,
|
||||
std::size_t... I0, std::size_t... I1, std::size_t... I2, std::size_t... I3>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3,
|
||||
std::index_sequence<I0...>, std::index_sequence<I1...>, std::index_sequence<I2...>, std::index_sequence<I3...>)
|
||||
{
|
||||
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)..., get<I3>(t3)...);
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class T3, class T4,
|
||||
std::size_t... I0, std::size_t... I1, std::size_t... I2, std::size_t... I3, std::size_t... I4>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4,
|
||||
std::index_sequence<I0...>, std::index_sequence<I1...>, std::index_sequence<I2...>, std::index_sequence<I3...>, std::index_sequence<I4...>)
|
||||
{
|
||||
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)..., get<I3>(t3)..., get<I4>(t4)...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple<>
|
||||
tuple_cat()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class Tuple,
|
||||
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Tuple const&
|
||||
tuple_cat(Tuple const& t)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <class T0, class T1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1)
|
||||
{
|
||||
return detail::tuple_cat(t0, t1,
|
||||
std::make_index_sequence<std::tuple_size<T0>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T1>::value>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2)
|
||||
{
|
||||
return detail::tuple_cat(t0, t1, t2,
|
||||
std::make_index_sequence<std::tuple_size<T0>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T1>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T2>::value>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class T3>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3)
|
||||
{
|
||||
return detail::tuple_cat(t0, t1, t2, t3,
|
||||
std::make_index_sequence<std::tuple_size<T0>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T1>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T2>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T3>::value>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class T3, class T4>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4)
|
||||
{
|
||||
return detail::tuple_cat(t0, t1, t2, t3, t4,
|
||||
std::make_index_sequence<std::tuple_size<T0>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T1>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T2>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T3>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T4>::value>{});
|
||||
}
|
||||
|
||||
template <class T0, class T1, class T2, class T3, class T4, class T5, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts)
|
||||
{
|
||||
return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), t5, ts...);
|
||||
}
|
||||
#endif
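// Hedged usage sketch (not part of the original source): tuple_cat flattens its
// arguments left to right, exactly like std::tuple_cat.
#if 0
auto abc = cute::tuple_cat(cute::make_tuple(1, 2), cute::make_tuple(3.0f));
// abc holds (1, 2, 3.0f); get<2>(abc) == 3.0f.
#endif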
|
||||
|
||||
#if 0
|
||||
// Outer-Inner indexing trick to concat all tuples at once
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <std::size_t... Ns>
|
||||
struct tuple_cat_helper
|
||||
{
|
||||
static constexpr cute::array<std::size_t,sizeof...(Ns)> ns = {Ns...};
|
||||
|
||||
static constexpr std::size_t total_size() {
|
||||
std::size_t sum = 0;
|
||||
for (std::size_t n : ns) sum += n;
|
||||
return sum;
|
||||
}
|
||||
static constexpr std::size_t total_size_ = total_size();
|
||||
|
||||
static constexpr auto values() {
|
||||
cute::array<std::size_t[2],total_size_> outer_inner = {};
|
||||
|
||||
std::size_t idx = 0;
|
||||
for (std::size_t i = 0; i < ns.size(); ++i) {
|
||||
for (std::size_t j = 0; j < ns[i]; ++j, ++idx) {
|
||||
outer_inner[idx][0] = i;
|
||||
outer_inner[idx][1] = j;
|
||||
}
|
||||
}
|
||||
return outer_inner;
|
||||
}
|
||||
static constexpr auto outer_inner_ = values();
|
||||
|
||||
using total_sequence = std::make_index_sequence<total_size_>;
|
||||
};
|
||||
|
||||
template <class Helper, class Tuple, std::size_t... I>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(Tuple const& t, std::index_sequence<I...>)
|
||||
{
|
||||
return cute::make_tuple(get<Helper::outer_inner_[I][1]>(get<Helper::outer_inner_[I][0]>(t))...);
|
||||
}
|
||||
|
||||
template <class T0, class T1,
|
||||
std::size_t... I0, std::size_t... I1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1,
|
||||
std::index_sequence<I0...>, std::index_sequence<I1...>)
|
||||
{
|
||||
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tuple<>
|
||||
tuple_cat()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class Tuple,
|
||||
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Tuple const&
|
||||
tuple_cat(Tuple const& t)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <class T0, class T1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(T0 const& t0, T1 const& t1)
|
||||
{
|
||||
return detail::tuple_cat(t0, t1,
|
||||
std::make_index_sequence<std::tuple_size<T0>::value>{},
|
||||
std::make_index_sequence<std::tuple_size<T1>::value>{});
|
||||
}
|
||||
|
||||
template <class... Tuples>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tuple_cat(Tuples const&... ts)
|
||||
{
|
||||
using Helper = detail::tuple_cat_helper<std::tuple_size<Tuples>::value...>;
|
||||
return detail::tuple_cat<Helper>(make_tuple(ts...), typename Helper::total_sequence{});
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// Equality operators
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <std::size_t I, class TupleA, class TupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
equal_impl(TupleA const& a, TupleB const& b)
|
||||
{
|
||||
if constexpr (I == std::tuple_size<TupleA>::value) {
|
||||
return cute::true_type{}; // Terminal: TupleA is exhausted
|
||||
} else if constexpr (I == std::tuple_size<TupleB>::value) {
|
||||
return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted
|
||||
} else {
|
||||
return (get<I>(a) == get<I>(b)) && equal_impl<I+1>(a,b);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class TupleT, class TupleU,
|
||||
__CUTE_REQUIRES(is_tuple<TupleT>::value && is_tuple<TupleU>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(TupleT const& t, TupleU const& u)
|
||||
{
|
||||
return detail::equal_impl<0>(t, u);
|
||||
}
|
||||
|
||||
template <class TupleT, class TupleU,
|
||||
__CUTE_REQUIRES(is_tuple<TupleT>::value ^ is_tuple<TupleU>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(TupleT const& t, TupleU const& u)
|
||||
{
|
||||
return cute::false_type{};
|
||||
}
|
||||
|
||||
template <class TupleT, class TupleU,
|
||||
__CUTE_REQUIRES(is_tuple<TupleT>::value && is_tuple<TupleU>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator!=(TupleT const& t, TupleU const& u)
|
||||
{
|
||||
return !(t == u);
|
||||
}
|
||||
|
||||
template <class TupleT, class TupleU,
|
||||
__CUTE_REQUIRES(is_tuple<TupleT>::value ^ is_tuple<TupleU>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator!=(TupleT const& t, TupleU const& u)
|
||||
{
|
||||
return cute::true_type{};
|
||||
}
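// Hedged sketch (not part of the original source) of what the overloads above produce:
// comparing tuples of static constants folds to cute::true_type / cute::false_type at
// compile time, tuples of runtime values fold to a plain bool, and a tuple compared
// against a non-tuple is unconditionally unequal.
//
//   make_tuple(Int<2>{}, Int<4>{}) == make_tuple(Int<2>{}, Int<4>{})  -> cute::true_type{}
//   make_tuple(1, 2) == make_tuple(1, 3)                              -> false (runtime bool)
//   make_tuple(1, 2) == 7                                             -> cute::false_type{}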
|
||||
|
||||
//
|
||||
// Comparison operators
|
||||
//
|
||||
|
||||
//
|
||||
// There are many ways to compare tuples of elements, and because CuTe is built
|
||||
// on parameterizing layouts of coordinates, some comparisons are appropriate
|
||||
// only in certain cases.
|
||||
// -- lexicographical comparison [reverse, reflected, revref]
|
||||
// -- colexicographical comparison [reverse, reflected, revref]
|
||||
// -- element-wise comparison [any,all]
|
||||
// This can be very confusing. To avoid errors in selecting the appropriate
|
||||
// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple.
|
||||
//
|
||||
// That said, see int_tuple for more explicitly named common comparison ops.
|
||||
//
|
||||
|
||||
//
|
||||
// Shortcuts
|
||||
//
|
||||
|
||||
//using std::get;
|
||||
using std::tuple_size;
|
||||
using std::tuple_element;
|
||||
using std::tuple_element_t;
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Tuple, std::size_t... Is>
|
||||
CUTE_HOST_DEVICE void print_tuple(Tuple const& t,
|
||||
std::index_sequence<Is...>, char s = '(', char e = ')')
|
||||
{
|
||||
using eat = int[];
|
||||
using cute::print;
|
||||
(void) eat {(print(s), 0),
|
||||
(print(Is == 0 ? "" : ","), print(get<Is>(t)), 0)...,
|
||||
(print(e), 0)};
|
||||
}
|
||||
|
||||
template <class Tuple, std::size_t... Is>
|
||||
CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t,
|
||||
std::index_sequence<Is...>, char s = '(', char e = ')')
|
||||
{
|
||||
using eat = int[];
|
||||
(void) eat {(void(os << s), 0),
|
||||
(void(os << (Is == 0 ? "" : ",") << get<Is>(t)), 0)...,
|
||||
(void(os << e), 0)};
|
||||
return os;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Tuple,
|
||||
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
|
||||
CUTE_HOST_DEVICE void print(Tuple const& t)
|
||||
{
|
||||
return detail::print_tuple(t, std::make_index_sequence<std::tuple_size<Tuple>::value>{});
|
||||
}
|
||||
|
||||
template <class Tuple,
|
||||
__CUTE_REQUIRES(is_tuple<Tuple>::value)>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t)
|
||||
{
|
||||
return detail::print_tuple_os(os, t, std::make_index_sequence<std::tuple_size<Tuple>::value>{});
|
||||
}
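// Hedged sketch (not part of the original source): printing a nested tuple produces
// nested parentheses, e.g. print(make_tuple(1, make_tuple(2, 3))) writes "(1,(2,3))".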
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
//
|
||||
// std:: compatibility
|
||||
//
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::tuple<T...>>
|
||||
: std::integral_constant<std::size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
struct tuple_element<I, cute::tuple<T...>>
|
||||
: std::tuple_element<I, std::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end std
|
||||
include/cute/container/type_list.hpp (new file, 84 lines)
@@ -0,0 +1,84 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class T>
|
||||
struct type_c {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class... T>
|
||||
struct type_list {};
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::type_list
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <cute/container/tuple.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <int I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
std::tuple_element_t<I, type_list<T...>>
|
||||
get(type_list<T...>&) noexcept {
|
||||
return {};
|
||||
}
|
||||
template <int I, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
std::tuple_element_t<I, type_list<T...>>
|
||||
get(type_list<T...> const& t) noexcept {
|
||||
return {};
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::type_list<T...>>
|
||||
: std::integral_constant<std::size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
struct tuple_element<I, cute::type_list<T...>>
|
||||
: cute::type_c<typename std::tuple_element<I, std::tuple<T...>>::type>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
include/cute/int_tuple.hpp (new file, 827 lines)
@@ -0,0 +1,827 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/container/array.hpp>
|
||||
#include <cute/algorithm/tuple_algorithms.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class... Ts>
|
||||
using IntTuple = cute::tuple<Ts...>;
|
||||
|
||||
// Construct an IntTuple with all value-elements
|
||||
template <class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
IntTuple<Ts...>
|
||||
make_int_tuple(Ts const&... t)
|
||||
{
|
||||
return {t...};
|
||||
}
|
||||
|
||||
/** if rank(int) == 1, then get<0>(int) should work too
|
||||
*/
|
||||
template <std::size_t I, class T, __CUTE_REQUIRES(is_integral<remove_cvref_t<T>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(T&& t) noexcept
|
||||
{
|
||||
static_assert(I == 0, "Index out of range");
|
||||
return static_cast<T&&>(t);
|
||||
}
|
||||
|
||||
/** Custom recursive get for anything that implements get<I>(.)
|
||||
*/
|
||||
template <std::size_t I0, std::size_t I1, std::size_t... Is, class Tuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
get(Tuple&& t) noexcept
|
||||
{
|
||||
return get<I1,Is...>(get<I0>(static_cast<Tuple&&>(t)));
|
||||
}
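// Hedged sketch (not part of the original source): multi-index get drills into nested
// tuples one mode at a time, and get<0> on a plain integer returns the integer itself.
//
//   get<1,0>(make_int_tuple(2, make_int_tuple(3, 4)))  -> 3
//   get<0>(5)                                           -> 5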
|
||||
|
||||
//
|
||||
// rank
|
||||
//
|
||||
|
||||
template <int... Is, class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
rank(IntTuple const& t)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
return Int<tuple_size<IntTuple>::value>{};
|
||||
} else {
|
||||
return Int<1>{};
|
||||
}
|
||||
} else {
|
||||
return rank(get<Is...>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class IntTuple>
|
||||
using rank_t = decltype(rank(std::declval<IntTuple>()));
|
||||
|
||||
template <class IntTuple>
|
||||
static constexpr int rank_v = rank_t<IntTuple>::value;
|
||||
|
||||
//
|
||||
// shape
|
||||
//
|
||||
|
||||
template <class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
shape(IntTuple const& s)
|
||||
{
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
return transform(s, [](auto const& a) { return shape(a); });
|
||||
} else {
|
||||
return s;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <int I, int... Is, class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
shape(IntTuple const& s)
|
||||
{
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
return shape<Is...>(get<I>(s));
|
||||
} else {
|
||||
return get<I,Is...>(shape(s));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// max
|
||||
//
|
||||
|
||||
template <class T0, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
max(T0 const& t0, Ts const&... ts)
|
||||
{
|
||||
if constexpr (is_tuple<T0>::value) {
|
||||
return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...);
|
||||
} else if constexpr (sizeof...(Ts) == 0) {
|
||||
return t0;
|
||||
} else {
|
||||
return cute::max(t0, cute::max(ts...));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// min
|
||||
//
|
||||
|
||||
template <class T0, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
min(T0 const& t0, Ts const&... ts)
|
||||
{
|
||||
if constexpr (is_tuple<T0>::value) {
|
||||
return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...);
|
||||
} else if constexpr (sizeof...(Ts) == 0) {
|
||||
return t0;
|
||||
} else {
|
||||
return cute::min(t0, cute::min(ts...));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// depth
|
||||
//
|
||||
|
||||
template <int... Is, class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
depth(IntTuple const& t)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); });
|
||||
} else {
|
||||
return Int<0>{};
|
||||
}
|
||||
} else {
|
||||
return depth(get<Is...>(t));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
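// Hedged sketch (not part of the original source): rank counts top-level modes, depth
// counts nesting levels, and both return static Int<> constants when they can.
//
//   rank (make_int_tuple(2, make_int_tuple(3, 4)))   -> Int<2>{}
//   rank<1>(make_int_tuple(2, make_int_tuple(3, 4))) -> Int<2>{}
//   depth(make_int_tuple(2, make_int_tuple(3, 4)))   -> Int<2>{}
//   depth(7)                                          -> Int<0>{}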
|
||||
|
||||
template <class Tuple>
|
||||
using depth_t = decltype(depth(std::declval<Tuple>()));
|
||||
|
||||
template <class Tuple>
|
||||
static constexpr int depth_v = depth_t<Tuple>::value;
|
||||
|
||||
//
|
||||
// product
|
||||
//
|
||||
|
||||
template <class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
product(IntTuple const& a)
|
||||
{
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
return cute::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); });
|
||||
} else {
|
||||
return a;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// Product of a subrange
|
||||
template <int B, int E, class Tuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
product(Tuple const& a)
|
||||
{
|
||||
return detail::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }, make_range<B,E>{});
|
||||
}
|
||||
|
||||
template <class Tuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
product_each(Tuple const& t)
|
||||
{
|
||||
return transform(t, [](auto const& x) { return product(x); });
|
||||
}
|
||||
|
||||
// Return the product of elements in a mode
|
||||
template <int... Is, class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
size(IntTuple const& a)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) {
|
||||
return product(a);
|
||||
} else {
|
||||
return product(get<Is...>(a));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class IntTuple>
|
||||
static constexpr int size_v = decltype(size(std::declval<IntTuple>()))::value;
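// For example (an illustrative sketch using make_shape from the container headers):
//   auto s = make_shape(2, make_shape(3,4));
//   product(s)        // 24     -- product over all leaves
//   size(s)           // 24     -- same as product(s)
//   size<1>(s)        // 12     -- product of mode-1 only
//   product_each(s)   // (2,12) -- one product per top-level mode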
|
||||
|
||||
//
|
||||
// sum
|
||||
//
|
||||
|
||||
template <class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
sum(IntTuple const& a)
|
||||
{
|
||||
if constexpr (is_tuple<IntTuple>::value) {
|
||||
return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); });
|
||||
} else {
|
||||
return a;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// inner_product
|
||||
//
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
inner_product(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
|
||||
return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); },
|
||||
[](auto const&... v) { return (Int<0>{} + ... + v); });
|
||||
} else {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
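// For example (illustrative): inner_product(make_tuple(2,3), make_tuple(4,5))
// evaluates 2*4 + 3*5 = 23; congruent nested tuples are reduced recursively.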
|
||||
|
||||
//
|
||||
// ceil_div
|
||||
//
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
ceil_div(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
|
||||
constexpr int R = tuple_size<IntTupleA>::value;  // Missing ranks in TupleB are implicitly 1
|
||||
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); });
|
||||
} else {
|
||||
return (a + b - Int<1>{}) / b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
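// For example (illustrative): ceil_div(10, 3) == 4, and
// ceil_div(make_shape(10,7), make_shape(3)) == (4,7) since the missing mode of
// the second argument is treated as 1.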
|
||||
|
||||
/** Division for Shapes
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
shape_div(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value) {
|
||||
if constexpr (is_tuple<IntTupleB>::value) { // tuple tuple
|
||||
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
|
||||
return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); });
|
||||
} else { // tuple int
|
||||
auto const [result, rest] = fold(a, make_tuple(make_tuple(), b),
|
||||
[] (auto const& init, auto const& ai) {
|
||||
return make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
if constexpr (is_tuple<IntTupleB>::value) { // int tuple
|
||||
return shape_div(a, product(b));
|
||||
} else { // int int
|
||||
//assert(a % b == 0 || b % a == 0);
|
||||
return a / b != 0 ? a / b : signum(a) * signum(b); // divide with rounding away from zero
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
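// For example (an illustrative sketch):
//   shape_div(make_shape(6,4), make_shape(3,2))   // (2,2) -- mode-by-mode
//   shape_div(make_shape(6,4), 12)                // (1,2) -- 12 is divided out of the modes left-to-right
//   shape_div(2, 4)                               // 1     -- rounds away from zero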
|
||||
|
||||
/** Division for Shapes that are static constants
|
||||
* @pre t % u == 0 || u % t == 0
|
||||
* @result if t % u == 0, then t / u
|
||||
* if u % t == 0, then signum(t) * signum(u)
|
||||
*/
|
||||
template <class T, T t, class U, U u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<decltype(shape_div(t,u)), shape_div(t,u)>
|
||||
shape_div(constant<T, t> const&, constant<U, u> const&)
|
||||
{
|
||||
static_assert(t % u == 0 || u % t == 0, "Static shape_div failure");
|
||||
return {};
|
||||
}
|
||||
|
||||
/** Return a tuple with the same profile as A, scaled by the corresponding elements of B
|
||||
*/
|
||||
template <class A, class B>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_scale(A const& a, B const& b)
|
||||
{
|
||||
if constexpr (is_tuple<A>::value) {
|
||||
return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); });
|
||||
} else {
|
||||
return a * product(b);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
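// For example (illustrative): elem_scale(make_tuple(2,3), make_tuple(4,5)) is
// (8,15), and elem_scale(2, make_tuple(4,5)) is 40 since a scalar is scaled by product(b).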
|
||||
|
||||
/** Test if two IntTuple have the same profile (hierarchical rank division)
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
congruent(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
return bool_constant<std::is_same<decltype(repeat_like(shape(a),_0{})),
|
||||
decltype(repeat_like(shape(b),_0{}))>::value>{};
|
||||
}
|
||||
|
||||
template <class A, class B>
|
||||
using is_congruent = decltype(congruent(std::declval<A>(), std::declval<B>()));
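// For example (illustrative): congruent(make_tuple(1,2), make_tuple(3,4)) is
// true_type (same profile, different values), while
// congruent(make_tuple(1, make_tuple(2,3)), make_tuple(4,5)) is false_type.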
|
||||
|
||||
/** Test if Shape B is compatible with Shape A:
|
||||
* Any coordinate into A can also be used as a coordinate into B
|
||||
* A <= B is a partially ordered set of factored shapes
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compatible(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
if constexpr (tuple_size<IntTupleA>::value != tuple_size<IntTupleB>::value) {
|
||||
return false_type{};
|
||||
} else {
|
||||
return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); },
|
||||
[](auto const&... z) { return (true_type{} && ... && z); });
|
||||
}
|
||||
} else if constexpr (is_integral<IntTupleA>::value) {
|
||||
return a == size(b);
|
||||
} else if constexpr (is_integral<IntTupleB>::value) {
|
||||
return false_type{};
|
||||
} else {
|
||||
return compatible(shape(a), shape(b));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class A, class B>
|
||||
using is_compatible = decltype(compatible(std::declval<A>(), std::declval<B>()));
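// For example (illustrative): compatible(6, make_shape(2,3)) is true (a scalar
// coordinate into 6 is also a valid coordinate into (2,3)),
// compatible(make_shape(2,3), make_shape(2, make_shape(1,3))) is true, but
// compatible(make_shape(2,3), 6) is false.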
|
||||
|
||||
/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1>
|
||||
*/
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
filter_zeros(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value) {
|
||||
return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); });
|
||||
} else if constexpr (is_constant<0, IntTupleA>::value) {
|
||||
return Int<1>{};
|
||||
} else {
|
||||
return b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class Tuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
filter_zeros(Tuple const& t)
|
||||
{
|
||||
return filter_zeros(t, t);
|
||||
}
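// For example (illustrative): filter_zeros(make_tuple(Int<0>{}, 5), make_tuple(7, 9))
// is (_1, 9): the element of b paired with a static 0 in a is replaced by Int<1>.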
|
||||
|
||||
//
|
||||
// Converters and constructors with arrays and params
|
||||
//
|
||||
|
||||
/** Make an IntTuple of rank N from an Indexable array.
|
||||
* Access elements up to a dynamic index n, then use init (requires compatible types)
|
||||
* Consider cute::take<B,E> if all indexing is known to be valid
|
||||
* \code
|
||||
* std::vector<int> a = {6,3,4};
|
||||
* auto tup = make_int_tuple<5>(a, a.size(), 0); // (6,3,4,0,0)
|
||||
* \endcode
|
||||
*/
|
||||
template <int N, class Indexable, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_int_tuple(Indexable const& t, int n, T const& init)
|
||||
{
|
||||
static_assert(N > 0);
|
||||
if constexpr (N == 1) {
|
||||
return 0 < n ? t[0] : init;
|
||||
} else {
|
||||
return transform(make_seq<N>{}, [&](auto i) { return i < n ? t[i] : init; });
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
/** Fill the dynamic values of a Tuple with values from another Tuple
|
||||
* \code
|
||||
* auto params = make_int_tuple(6,3,4);
|
||||
* cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>> result;
|
||||
* fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2)
|
||||
* \endcode
|
||||
*/
|
||||
template <class Tuple, class TupleV>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
fill_int_tuple_from(Tuple& result, TupleV const& vals)
|
||||
{
|
||||
return fold(result, vals, [](auto const& init, auto&& r) {
|
||||
if constexpr (is_static<remove_cvref_t<decltype(r)>>::value) { // Skip static elements of result
|
||||
return init;
|
||||
} else if constexpr (is_tuple<remove_cvref_t<decltype(r)>>::value) { // Recurse into tuples
|
||||
return fill_int_tuple_from(r, init);
|
||||
} else { // Assign and consume arg
|
||||
static_assert(tuple_size<remove_cvref_t<decltype(init)>>::value > 0, "Not enough values to fill with!");
|
||||
r = get<0>(init);
|
||||
return remove<0>(init);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
});
|
||||
}
|
||||
|
||||
/** Make a "Tuple" by filling in the dynamic values in order from the arguments
|
||||
* \code
|
||||
* using result_t = cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>>;
|
||||
* auto result = make_int_tuple_from<result_t>(6,3,4); // (_1,(6,3,_3),4,_2)
|
||||
* \endcode
|
||||
*/
|
||||
template <class Tuple, class... Ts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Tuple
|
||||
make_int_tuple_from(Ts const&... ts)
|
||||
{
|
||||
Tuple result = Tuple{};
|
||||
fill_int_tuple_from(result, make_tuple(ts...));
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Convert a tuple to a flat homogeneous array of type T
|
||||
* \code
|
||||
* auto tup = make_tuple(Int<1>{}, make_tuple(6,3,Int<3>{}),4,Int<2>{});
|
||||
* cute::array<uint64_t,6> result = to_array<uint64_t>(tup); // [1,6,3,3,4,2]
|
||||
* \endcode
|
||||
*/
|
||||
template <class T = int64_t, class IntTuple>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_array(IntTuple const& t)
|
||||
{
|
||||
auto flat_t = flatten_to_tuple(t);
|
||||
constexpr int N = tuple_size<decltype(flat_t)>::value;
|
||||
cute::array<T,N> result;
|
||||
for_each(make_seq<N>{}, [&] (auto i) { result[i] = get<i>(flat_t); });
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Comparison operators
|
||||
//
|
||||
|
||||
//
|
||||
// There are many ways to compare tuples of elements and, because CuTe is built
|
||||
// on parameterizing layouts of coordinates, some comparisons are appropriate
|
||||
// only in certain cases.
|
||||
// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout
|
||||
// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout
|
||||
// -- element-wise comparison [any,all] :
|
||||
// This can be very confusing. To avoid errors in selecting the appropriate
|
||||
// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple.
|
||||
//
|
||||
// When actually desiring to order coordinates, the user should map them to
|
||||
// their indices within the Layout they came from:
|
||||
// e.g. layoutX(coordA) < layoutX(coordB)
|
||||
// That said, we implement the three most common ways to compare tuples below.
|
||||
// These are implemented with slightly more explicit names than op<.
|
||||
//
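// For example (an illustrative sketch), with a = make_tuple(1,2) and b = make_tuple(2,1):
//   lex_less(a, b)    // true  -- mode-0 is the most significant (RowMajor order)
//   colex_less(a, b)  // false -- the last mode is the most significant (ColMajor order)
//   elem_less(a, b)   // false -- requires every element of a to be less than its partner in b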
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lex_less(IntTupleA const& a, IntTupleB const& b);
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
colex_less(IntTupleA const& a, IntTupleB const& b);
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_less(IntTupleA const& a, IntTupleB const& b);
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <std::size_t I, class TupleA, class TupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lex_less_impl(TupleA const& a, TupleB const& b)
|
||||
{
|
||||
if constexpr (I == tuple_size<TupleB>::value) {
|
||||
return cute::false_type{}; // Terminal: TupleB is exhausted
|
||||
} else if constexpr (I == tuple_size<TupleA>::value) {
|
||||
return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted
|
||||
} else {
|
||||
return lex_less(get<I>(a), get<I>(b)) || (get<I>(a) == get<I>(b) && lex_less_impl<I+1>(a,b));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <std::size_t I, class TupleA, class TupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
colex_less_impl(TupleA const& a, TupleB const& b)
|
||||
{
|
||||
if constexpr (I == tuple_size<TupleB>::value) {
|
||||
return cute::false_type{}; // Terminal: TupleB is exhausted
|
||||
} else if constexpr (I == tuple_size<TupleA>::value) {
|
||||
return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted
|
||||
} else {
|
||||
constexpr std::size_t A = tuple_size<TupleA>::value - 1 - I;
|
||||
constexpr std::size_t B = tuple_size<TupleB>::value - 1 - I;
|
||||
return colex_less(get<A>(a), get<B>(b)) || (get<A>(a) == get<B>(b) && colex_less_impl<I+1>(a,b));
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <std::size_t I, class TupleA, class TupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_less_impl(TupleA const& a, TupleB const& b)
|
||||
{
|
||||
if constexpr (I == tuple_size<TupleA>::value) {
|
||||
return cute::true_type{}; // Terminal: TupleA is exhausted
|
||||
} else if constexpr (I == tuple_size<TupleB>::value) {
|
||||
return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted
|
||||
} else {
|
||||
return elem_less(get<I>(a), get<I>(b)) && elem_less_impl<I+1>(a,b);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Lexicographical comparison
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lex_less(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
return detail::lex_less_impl<0>(a, b);
|
||||
} else {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lex_leq(T const& t, U const& u) {
|
||||
return !lex_less(u, t);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lex_gtr(T const& t, U const& u) {
|
||||
return lex_less(u, t);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lex_geq(T const& t, U const& u) {
|
||||
return !lex_less(t, u);
|
||||
}
|
||||
|
||||
// Colexicographical comparison
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
colex_less(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
return detail::colex_less_impl<0>(a, b);
|
||||
} else {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
colex_leq(T const& t, U const& u) {
|
||||
return !colex_less(u, t);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
colex_gtr(T const& t, U const& u) {
|
||||
return colex_less(u, t);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
colex_geq(T const& t, U const& u) {
|
||||
return !colex_less(t, u);
|
||||
}
|
||||
|
||||
// Elementwise [all] comparison
|
||||
|
||||
template <class IntTupleA, class IntTupleB>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_less(IntTupleA const& a, IntTupleB const& b)
|
||||
{
|
||||
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
|
||||
return detail::elem_less_impl<0>(a, b);
|
||||
} else {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_leq(T const& t, U const& u) {
|
||||
return !elem_less(u, t);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_gtr(T const& t, U const& u) {
|
||||
return elem_less(u, t);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
elem_geq(T const& t, U const& u) {
|
||||
return !elem_less(t, u);
|
||||
}
|
||||
|
||||
/** Increment a (dynamic) coord lexicographically within a shape
|
||||
* \code
|
||||
* auto shape = make_shape(1,2,make_shape(2,3),3);
|
||||
*
|
||||
* int i = 0;
|
||||
* for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) {
|
||||
* std::cout << i++ << ": " << coord << std::endl;
|
||||
* }
|
||||
* assert(i == size(shape));
|
||||
* \endcode
|
||||
*/
|
||||
template <class Coord, class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
increment(Coord& coord, Shape const& shape);
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Coord, class Shape, int I0, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
increment(Coord& coord, Shape const& shape, seq<I0,Is...>)
|
||||
{
|
||||
cute::increment(get<I0>(coord), get<I0>(shape));
|
||||
if constexpr (sizeof...(Is) != 0) {
|
||||
if (back(get<I0>(coord)) == back(get<I0>(shape))) {
|
||||
back(get<I0>(coord)) = 0;
|
||||
increment(coord, shape, seq<Is...>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Coord, class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
increment(Coord& coord, Shape const& shape)
|
||||
{
|
||||
if constexpr (is_integral<Coord>::value && is_integral<Shape>::value) {
|
||||
++coord;
|
||||
} else if constexpr (is_tuple<Coord>::value && is_tuple<Shape>::value) {
|
||||
static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched ranks");
|
||||
detail::increment(coord, shape, tuple_seq<Coord>{});
|
||||
} else {
|
||||
static_assert(sizeof(Coord) == 0, "Invalid parameters");
|
||||
}
|
||||
}
|
||||
|
||||
struct ForwardCoordIteratorSentinal
|
||||
{};
|
||||
|
||||
// A forward iterator for a coordinate that starts from zero and goes to shape
|
||||
template <class Coord, class Shape>
|
||||
struct ForwardCoordIterator
|
||||
{
|
||||
static_assert(is_congruent<Coord, Shape>::value);
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Coord const& operator*() const { return coord; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ForwardCoordIterator& operator++() { increment(coord, shape); return *this; }
|
||||
|
||||
// Sentinel for the end of the implied range
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); }
|
||||
// NOTE: These are expensive, avoid use
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; }
|
||||
|
||||
Coord coord;
|
||||
Shape const& shape;
|
||||
};
|
||||
|
||||
// A forward iterator for a coordinate that starts from zero
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_coord_iterator(Shape const& shape)
|
||||
{
|
||||
auto coord = repeat_like(shape, int(0));
|
||||
return ForwardCoordIterator<decltype(coord),Shape>{coord,shape};
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
1638  include/cute/layout.hpp  (new file; diff suppressed because it is too large)
388   include/cute/numeric/arithmetic_tuple.hpp  (new file)
@ -0,0 +1,388 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/algorithm/tuple_algorithms.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct ArithmeticTuple : tuple<T...>
|
||||
{
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple(ArithmeticTuple<U...> const& u)
|
||||
: tuple<T...>(static_cast<tuple<U...> const&>(u)) {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple(tuple<U...> const& u)
|
||||
: tuple<T...>(u) {}
|
||||
|
||||
template <class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTuple(U const&... u)
|
||||
: tuple<T...>(u...) {}
|
||||
};
|
||||
|
||||
template <class... T>
|
||||
struct is_tuple<ArithmeticTuple<T...>> : true_type {};
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_arithmetic_tuple(T const&... t) {
|
||||
return ArithmeticTuple<T...>(t...);
|
||||
}
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(tuple<T...> const& t) {
|
||||
return ArithmeticTuple<T...>(t);
|
||||
}
|
||||
|
||||
//
|
||||
// Numeric operators
|
||||
//
|
||||
|
||||
// Addition
|
||||
template <class... T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
|
||||
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
|
||||
}
|
||||
|
||||
template <class... T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
|
||||
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
|
||||
}
|
||||
|
||||
template <class... T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
|
||||
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
|
||||
}
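// For example (illustrative): operands of different rank are zero-extended before
// the elementwise sum, so make_arithmetic_tuple(1,2) + make_arithmetic_tuple(3,4,5)
// is (4,6,5).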
|
||||
|
||||
//
|
||||
// Special cases
|
||||
//
|
||||
|
||||
template <class T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(constant<T,0>, ArithmeticTuple<U...> const& u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
template <class... T, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, constant<U,0>) {
|
||||
return t;
|
||||
}
|
||||
|
||||
//
|
||||
// ArithmeticTupleIterator
|
||||
//
|
||||
|
||||
template <class ArithTuple>
|
||||
struct ArithmeticTupleIterator
|
||||
{
|
||||
ArithTuple coord_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTupleIterator() : coord_() {}
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithmeticTupleIterator(ArithTuple const& coord) : coord_(coord) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ArithTuple const& operator*() const { return coord_; }
|
||||
|
||||
template <class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto operator+(Coord const& c) const {
|
||||
return ArithmeticTupleIterator<decltype(coord_ + c)>(coord_ + c);
|
||||
}
|
||||
|
||||
template <class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto operator[](Coord const& c) const { return *(*this + c); }
|
||||
};
|
||||
|
||||
template <class ArithTuple>
|
||||
CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter) {
|
||||
printf("ArithTuple"); print(iter.coord_);
|
||||
}
|
||||
|
||||
//
|
||||
// ArithmeticTuple "basis" elements
|
||||
//
|
||||
|
||||
// Abstract value:
|
||||
// A ScaledBasis<T,N> abstractly represents an ArithmeticTuple of rank at least N+1:
|
||||
// (_0,_0,...,T,_0,...)
|
||||
|
||||
template <class T, int N>
|
||||
struct ScaledBasis : private tuple<T>
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
ScaledBasis(T const& t = {}) : tuple<T>(t) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) value() { return get<0>(static_cast<tuple<T> &>(*this)); }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto) value() const { return get<0>(static_cast<tuple<T> const&>(*this)); }
|
||||
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
auto mode() { return Int<N>{}; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_scaled_basis : false_type {};
|
||||
template <class T, int N>
|
||||
struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
|
||||
|
||||
template <class T, int N>
|
||||
struct is_integral<ScaledBasis<T,N>> : true_type {};
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
basis_value(T const& e) {
|
||||
return e;
|
||||
}
|
||||
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
basis_value(ScaledBasis<T,N> const& e) {
|
||||
return basis_value(e.value());
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <int... Ns>
|
||||
struct Basis;
|
||||
|
||||
template <>
|
||||
struct Basis<> {
|
||||
using type = Int<1>;
|
||||
};
|
||||
|
||||
template <int N, int... Ns>
|
||||
struct Basis<N,Ns...> {
|
||||
using type = ScaledBasis<typename Basis<Ns...>::type, N>;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <int... N>
|
||||
using E = typename detail::Basis<N...>::type;
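// For example (illustrative):
//   E<>     is Int<1>                  -- the trivial basis element
//   E<0>    is ScaledBasis<Int<1>,0>   -- a 1 in mode 0: (_1,_0,...)
//   E<1>    is ScaledBasis<Int<1>,1>   -- a 1 in mode 1: (_0,_1,...)
//   E<1,0>  is ScaledBasis<E<0>,1>     -- a 1 in mode (1,0) of a nested tuple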
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, int... I, int... J>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(T const& t, seq<I...>, seq<J...>) {
|
||||
return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...);
|
||||
}
|
||||
|
||||
template <class... T, int... I, int... J>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ArithmeticTuple<T...> const& t, seq<I...>, seq<J...>) {
|
||||
return make_arithmetic_tuple(get<I>(t)..., (void(J),Int<0>{})...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
// Turn a ScaledBasis<T,N> into a rank-M ArithmeticTuple
|
||||
// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0)
|
||||
template <int M, class T, int N>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
||||
static_assert(M > N, "Mismatched ranks");
|
||||
return detail::as_arithmetic_tuple(t.value(), make_seq<N>{}, make_seq<M-N-1>{});
|
||||
}
|
||||
|
||||
// Turn an ArithmeticTuple into a rank-M ArithmeticTuple
|
||||
// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0)
|
||||
template <int M, class... T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
as_arithmetic_tuple(ArithmeticTuple<T...> const& t) {
|
||||
static_assert(M >= sizeof...(T), "Mismatched ranks");
|
||||
return detail::as_arithmetic_tuple(t, make_seq<int(sizeof...(T))>{}, make_seq<M-int(sizeof...(T))>{});
|
||||
}
|
||||
|
||||
// Return a tuple of basis elements congruent to Shape, one basis element E<i...> per leaf of Shape
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_basis_like(Shape const& shape)
|
||||
{
|
||||
if constexpr (is_integral<Shape>::value) {
|
||||
return Int<1>{};
|
||||
} else {
|
||||
// Generate bases for each rank of shape
|
||||
return transform(tuple_seq<Shape>{}, [&](auto I) {
|
||||
// Generate bases for each rank of shape_i and add an i on front
|
||||
constexpr int i = decltype(I)::value; // NOTE: nvcc workaround
|
||||
return transform_leaf(make_basis_like(get<i>(shape)), [&](auto e) { return ScaledBasis<decltype(e),i>{}; });
|
||||
});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
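// For example (illustrative): make_basis_like(make_shape(2, make_shape(3,4)))
// yields (E<0>{}, (E<1,0>{}, E<1,1>{})) -- one basis element per leaf of the shape.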
|
||||
|
||||
// Equality
|
||||
template <class T, int N, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(ScaledBasis<T,N>, Int<M>) {
|
||||
return false_type{};
|
||||
}
|
||||
|
||||
template <int N, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(Int<N>, ScaledBasis<U,M>) {
|
||||
return false_type{};
|
||||
}
|
||||
|
||||
template <class T, int N, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
return bool_constant<M == N>{} && t.value() == u.value();
|
||||
}
|
||||
|
||||
// Multiplication
|
||||
template <class A, int N, class T,
|
||||
__CUTE_REQUIRES(cute::is_integral<A>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(A const& a, ScaledBasis<T,N> const& e) {
|
||||
return ScaledBasis<decltype(a*e.value()),N>{a*e.value()};
|
||||
}
|
||||
|
||||
template <int N, class T, class B,
|
||||
__CUTE_REQUIRES(cute::is_integral<B>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator*(ScaledBasis<T,N> const& e, B const& b) {
|
||||
return ScaledBasis<decltype(e.value()*b),N>{e.value()*b};
|
||||
}
|
||||
|
||||
// Addition
|
||||
template <int N, class T, class... U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
|
||||
constexpr int R = cute::max(N+1, int(sizeof...(U)));
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <class... T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
||||
constexpr int R = cute::max(int(sizeof...(T)), M+1);
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
|
||||
|
||||
template <int N, class T, int M, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
||||
constexpr int R = cute::max(N+1,M+1);
|
||||
return as_arithmetic_tuple<R>(t) + as_arithmetic_tuple<R>(u);
|
||||
}
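// For example (an illustrative sketch): each ScaledBasis is first expanded to an
// ArithmeticTuple of sufficient rank and then summed elementwise, so
//   E<0>{} + E<1>{}                        // (_1,_1)
//   E<1>{}*3 + make_arithmetic_tuple(5)    // (5,3)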
|
||||
|
||||
template <class T, class U, int M>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(constant<T,0>, ScaledBasis<U,M> const& u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
template <class T, int N, class U>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator+(ScaledBasis<T,N> const& t, constant<U,0>) {
|
||||
return t;
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
template <class T, int N>
|
||||
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e) {
|
||||
printf("%d:", N); print(e.value());
|
||||
}
|
||||
|
||||
template <class T, int N>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e) {
|
||||
return os << N << ":" << e.value();
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class... T>
|
||||
struct tuple_size<cute::ArithmeticTuple<T...>>
|
||||
: std::integral_constant<std::size_t, sizeof...(T)>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class... T>
|
||||
struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
||||
: std::tuple_element<I, std::tuple<T...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
51  include/cute/numeric/bfloat.hpp  (new file)
@ -0,0 +1,51 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <vector_types.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace cute {
|
||||
|
||||
using cutlass::bfloat16_t;
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v)
|
||||
{
|
||||
return os << float(v);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
163  include/cute/numeric/complex.hpp  (new file)
@ -0,0 +1,163 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
//#if defined(__CUDA_ARCH__)
|
||||
//# include <cuda/std/complex>
|
||||
//#else
|
||||
//# include <complex>
|
||||
//#endif
|
||||
|
||||
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
|
||||
// on line 656 of thrust/detail/type_traits.h.
|
||||
// These pragmas suppress the warnings.
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wconversion"
|
||||
#include <thrust/complex.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//#if defined(__CUDA_ARCH__)
|
||||
//template <class T>
|
||||
//using complex = cuda::std::complex<T>;
|
||||
//#else
|
||||
//template <class T>
|
||||
//using complex = std::complex<T>;
|
||||
//#endif
|
||||
|
||||
//template <class T>
|
||||
//using complex = thrust::complex<T>;
|
||||
|
||||
using thrust::complex;
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
T real(complex<T> const& z) {
|
||||
return z.real();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
T imag(complex<T> const& z) {
|
||||
return z.imag();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
complex<T> conj(complex<T> const& z) {
|
||||
return complex<T>(real(z), -imag(z));
|
||||
}
|
||||
|
||||
// cute::conj forwards scalars
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
T conj(T z) {
|
||||
return z;
|
||||
}
|
||||
|
||||
//CUTE_HOST_DEVICE constexpr
|
||||
//float conj(float z) { return z; }
|
||||
//CUTE_HOST_DEVICE constexpr
|
||||
//double conj(double z) { return z; }
|
||||
|
||||
/// Fused multiply-add for complex numbers
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
fma(complex<T> & d,
|
||||
complex<T> const& a,
|
||||
complex<T> const& b,
|
||||
complex<T> const& c)
|
||||
{
|
||||
d.real(c.real() + a.real() * b.real());
|
||||
d.imag(c.imag() + a.real() * b.imag());
|
||||
d.real(d.real() - a.imag() * b.imag());
|
||||
d.imag(d.imag() + a.imag() * b.real());
|
||||
}
|
||||
|
||||
/// Fused multiply-add for triplets
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
fma(complex<T> const& a,
|
||||
complex<T> const& b,
|
||||
complex<T> & c)
|
||||
{
|
||||
return fma(c, a, b, c);
|
||||
}
|
||||
|
||||
/// Used to determine the real-valued underlying type of a numeric type T
|
||||
template <class T>
|
||||
struct RealType {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
/// Partial specialization for complex-valued type
|
||||
template <class T>
|
||||
struct RealType<complex<T>> {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class T>
|
||||
struct is_complex {
|
||||
static bool const value = false;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_complex<complex<T>> {
|
||||
static bool const value = true;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Display utilities
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, complex<T> const& z)
|
||||
{
|
||||
T _r = z.real();
|
||||
T _i = z.imag();
|
||||
|
||||
if (bool(_i)) {
|
||||
return os << _r << "+i" << _i;
|
||||
} else {
|
||||
return os << _r;
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
43  include/cute/numeric/float8.hpp  (new file)
@ -0,0 +1,43 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <vector_types.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace cute {
|
||||
|
||||
using cutlass::float_e4m3_t;
|
||||
using cutlass::float_e5m2_t;
|
||||
|
||||
} // end namespace cute
|
||||
41  include/cute/numeric/half.hpp  (new file)
@ -0,0 +1,41 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <vector_types.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace cute {
|
||||
|
||||
using cutlass::half_t;
|
||||
|
||||
} // end namespace cute
|
||||
129  include/cute/numeric/int.hpp  (new file)
@ -0,0 +1,129 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#endif
|
||||
|
||||
#include <cute/numeric/integer_subbyte.hpp>
|
||||
#include <cute/numeric/uint128.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Signed integers
|
||||
//
|
||||
|
||||
using int8_t = std::int8_t;
|
||||
using int16_t = std::int16_t;
|
||||
using int32_t = std::int32_t;
|
||||
using int64_t = std::int64_t;
|
||||
|
||||
template <int N> struct int_bit;
|
||||
template <> struct int_bit< 2> { using type = cute::int2b_t; };
|
||||
template <> struct int_bit< 4> { using type = cute::int4b_t; };
|
||||
template <> struct int_bit< 8> { using type = int8_t; };
|
||||
template <> struct int_bit< 16> { using type = int16_t; };
|
||||
template <> struct int_bit< 32> { using type = int32_t; };
|
||||
template <> struct int_bit< 64> { using type = int64_t; };
|
||||
|
||||
template <int N>
|
||||
using int_bit_t = typename int_bit<N>::type;
|
||||
|
||||
template <int N>
|
||||
using int_byte = int_bit<8*N>;
|
||||
|
||||
template <int N>
|
||||
using int_byte_t = typename int_byte<N>::type;
|
||||
|
||||
//
|
||||
// Unsigned integers
|
||||
//
|
||||
|
||||
using uint8_t = std::uint8_t;
|
||||
using uint16_t = std::uint16_t;
|
||||
using uint32_t = std::uint32_t;
|
||||
using uint64_t = std::uint64_t;
|
||||
|
||||
template <int N> struct uint_bit;
|
||||
template <> struct uint_bit< 1> { using type = cute::uint1b_t; };
|
||||
template <> struct uint_bit< 2> { using type = cute::uint2b_t; };
|
||||
template <> struct uint_bit< 4> { using type = cute::uint4b_t; };
|
||||
template <> struct uint_bit< 8> { using type = uint8_t; };
|
||||
template <> struct uint_bit< 16> { using type = uint16_t; };
|
||||
template <> struct uint_bit< 32> { using type = uint32_t; };
|
||||
template <> struct uint_bit< 64> { using type = uint64_t; };
|
||||
template <> struct uint_bit<128> { using type = cute::uint128_t; };
|
||||
|
||||
template <int N>
|
||||
using uint_bit_t = typename uint_bit<N>::type;
|
||||
|
||||
template <int N>
|
||||
using uint_byte = uint_bit<8*N>;
|
||||
|
||||
template <int N>
|
||||
using uint_byte_t = typename uint_byte<N>::type;
|
||||
|
||||
//
|
||||
// sizeof_bytes
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct sizeof_bytes {
|
||||
static constexpr std::size_t value = sizeof(T);
|
||||
};
|
||||
template <class T>
|
||||
static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
|
||||
|
||||
//
|
||||
// sizeof_bits
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct sizeof_bits {
|
||||
static constexpr std::size_t value = sizeof(T) * 8;
|
||||
};
|
||||
template <>
|
||||
struct sizeof_bits<bool> {
|
||||
static constexpr std::size_t value = 1;
|
||||
};
|
||||
template <int Bits, bool Signed>
|
||||
struct sizeof_bits<integer_subbyte<Bits,Signed>> {
|
||||
static constexpr std::size_t value = Bits;
|
||||
};
|
||||
template <class T>
|
||||
static constexpr int sizeof_bits_v = sizeof_bits<T>::value;
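// For example (illustrative, assuming uint4b_t is the unsigned 4-bit
// integer_subbyte alias): uint_bit_t<16> is uint16_t, uint_byte_t<4> is
// uint32_t, sizeof_bits_v<uint4b_t> == 4, sizeof_bits_v<bool> == 1, and
// sizeof_bytes_v<uint64_t> == 8.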
|
||||
|
||||
} // namespace cute
|
||||
139  include/cute/numeric/integer_sequence.hpp  (new file)
@ -0,0 +1,139 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <utility> // std::integer_sequence
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
using std::integer_sequence;
|
||||
using std::make_integer_sequence;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class S, T Begin>
|
||||
struct make_integer_range_impl;
|
||||
|
||||
template <class T, T... N, T Begin>
|
||||
struct make_integer_range_impl<T, integer_sequence<T, N...>, Begin> {
|
||||
using type = integer_sequence<T, N+Begin...>;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class T, T Begin, T End>
|
||||
using make_integer_range = typename detail::make_integer_range_impl<
|
||||
T,
|
||||
make_integer_sequence<T, (End-Begin > 0) ? (End-Begin) : 0>,
|
||||
Begin>::type;
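// For example (illustrative): make_integer_range<int, 2, 5> is
// integer_sequence<int, 2, 3, 4>, and make_integer_range<int, 3, 3> is the
// empty integer_sequence<int>.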
|
||||
|
||||
//
|
||||
// Common aliases
|
||||
//
|
||||
|
||||
// int_sequence
|
||||
|
||||
template <int... Ints>
|
||||
using int_sequence = integer_sequence<int, Ints...>;
|
||||
|
||||
template <int N>
|
||||
using make_int_sequence = make_integer_sequence<int, N>;
|
||||
|
||||
template <int Begin, int End>
|
||||
using make_int_range = make_integer_range<int, Begin, End>;
|
||||
|
||||
// index_sequence
|
||||
|
||||
template <std::size_t... Ints>
|
||||
using index_sequence = integer_sequence<std::size_t, Ints...>;
|
||||
|
||||
template <std::size_t N>
|
||||
using make_index_sequence = make_integer_sequence<std::size_t, N>;
|
||||
|
||||
template <std::size_t Begin, std::size_t End>
|
||||
using make_index_range = make_integer_range<std::size_t, Begin, End>;
|
||||
|
||||
//
|
||||
// Shortcuts
|
||||
//
|
||||
|
||||
template <int... Ints>
|
||||
using seq = int_sequence<Ints...>;
|
||||
|
||||
template <int N>
|
||||
using make_seq = make_int_sequence<N>;
|
||||
|
||||
template <int Min, int Max>
|
||||
using make_range = make_int_range<Min, Max>;
|
||||
|
||||
template <class Tuple>
|
||||
using tuple_seq = make_seq<std::tuple_size<std::remove_reference_t<Tuple>>::value>;
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
|
||||
//
|
||||
// Specialize tuple-related functionality for cute::integer_sequence
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <std::size_t I, class T, T... Ints>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
std::tuple_element_t<I, integer_sequence<T, Ints...>>
|
||||
get(integer_sequence<T, Ints...>) {
|
||||
static_assert(I < sizeof...(Ints), "Index out of range");
|
||||
return {};
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
||||
template <class T, T... Ints>
|
||||
struct tuple_size<cute::integer_sequence<T, Ints...>>
|
||||
: std::integral_constant<std::size_t, sizeof...(Ints)>
|
||||
{};
|
||||
|
||||
template <std::size_t I, class T, T... Ints>
|
||||
struct tuple_element<I, cute::integer_sequence<T, Ints...>>
|
||||
: std::tuple_element<I, std::tuple<cute::integral_constant<T,Ints>...>>
|
||||
{};
|
||||
|
||||
} // end namespace std
|
||||
233
include/cute/numeric/integer_subbyte.hpp
Normal file
@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, bool Signed = true>
|
||||
struct integer_subbyte
|
||||
{
|
||||
/// Storage type
|
||||
using Storage = uint8_t;
|
||||
|
||||
/// Number of bits
|
||||
static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte");
|
||||
|
||||
/// External type
|
||||
using xint_t = typename std::conditional<Signed, int, unsigned>::type;
|
||||
|
||||
/// Bitmask for truncation from larger integers
|
||||
static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1);
|
||||
/// Bitmask for the sign bit
|
||||
static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
Storage storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// No operation
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
integer_subbyte() {}
|
||||
|
||||
/// Conversion from integer type
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
integer_subbyte(int value) // NOTE: Sign extension?
|
||||
: storage(reinterpret_cast<Storage const&>(value) & bits_mask_) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
integer_subbyte(unsigned value)
|
||||
: storage(reinterpret_cast<Storage const&>(value) & bits_mask_) {}
|
||||
|
||||
/// Convert to int or unsigned
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator xint_t() const {
|
||||
if (sign_mask_ & storage) { // Sign extend
|
||||
return xint_t(storage) | ~xint_t(bits_mask_);
|
||||
} else {
|
||||
return xint_t(storage);
|
||||
}
|
||||
}
|
||||
|
||||
/// Equality
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(integer_subbyte const& rhs) const {
|
||||
return storage == rhs.storage;
|
||||
}
|
||||
|
||||
/// Inequality
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator!=(integer_subbyte const& rhs) const {
|
||||
return storage != rhs.storage;
|
||||
}
|
||||
|
||||
/// Less than or equal
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator<=(integer_subbyte const& rhs) const {
|
||||
if (sign_mask_ & storage) {
|
||||
return !(rhs.storage < storage);
|
||||
} else {
|
||||
return storage < rhs.storage;
|
||||
}
|
||||
}
|
||||
|
||||
/// Less than
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator<(integer_subbyte const& rhs) const {
|
||||
if (sign_mask_ & storage) {
|
||||
return !(rhs.storage <= storage);
|
||||
} else {
|
||||
return storage < rhs.storage;
|
||||
}
|
||||
}
|
||||
|
||||
/// Greater than or equal
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator>=(integer_subbyte const& rhs) const {
|
||||
return !(*this < rhs);
|
||||
}
|
||||
|
||||
/// Greater than
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator>(integer_subbyte const& rhs) const {
|
||||
return !(*this <= rhs);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// 1-bit unsigned integer type
|
||||
using uint1b_t = integer_subbyte<1, false>;
|
||||
|
||||
/// 2-bit integer type
|
||||
using int2b_t = integer_subbyte<2, true>;
|
||||
|
||||
/// 2-bit unsigned integer type
|
||||
using uint2b_t = integer_subbyte<2, false>;
|
||||
|
||||
/// 4-bit integer type
|
||||
using int4b_t = integer_subbyte<4, true>;
|
||||
|
||||
/// 4-bit unsigned integer type
|
||||
using uint4b_t = integer_subbyte<4, false>;
|
||||
|
||||
/// 1-bit binary type
|
||||
using bin1_t = bool;
|
||||
|
||||
} // namespace cute
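The sub-byte types pack their value into the low bits of a uint8_t and sign-extend it on the way out. A small runtime sketch (it assumes a little-endian host, since the int constructor reinterprets the low byte):

```cpp
#include <cstdio>
#include <cute/numeric/integer_subbyte.hpp>

int main() {
  cute::int4b_t  a(7);    // stored as 0b0111
  cute::int4b_t  b(-3);   // stored as 0b1101 (two's complement, truncated to 4 bits)
  cute::uint4b_t c(18);   // 18 & 0xF == 2: values are truncated, not saturated

  std::printf("%d %d %u\n", int(a), int(b), unsigned(c));  // prints "7 -3 2"
  return 0;
}
```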
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
struct numeric_limits<cute::uint1b_t> {
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint1b_t const lowest() noexcept { return 0; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint1b_t const min() noexcept { return 0; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint1b_t const max() noexcept { return 1; }
|
||||
static constexpr bool is_integer = true;
|
||||
static constexpr bool is_signed = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<cute::int2b_t> {
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::int2b_t lowest() noexcept { return -2; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::int2b_t min() noexcept { return -2; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::int2b_t max() noexcept { return 1; }
|
||||
static constexpr bool is_integer = true;
|
||||
static constexpr bool is_signed = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<cute::uint2b_t> {
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint2b_t const lowest() noexcept { return 0; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint2b_t const min() noexcept { return 0; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint2b_t const max() noexcept { return 3; }
|
||||
static constexpr bool is_integer = true;
|
||||
static constexpr bool is_signed = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<cute::int4b_t> {
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::int4b_t lowest() noexcept { return -8; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::int4b_t min() noexcept { return -8; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::int4b_t max() noexcept { return 7; }
|
||||
static constexpr bool is_integer = true;
|
||||
static constexpr bool is_signed = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<cute::uint4b_t> {
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint4b_t const lowest() noexcept { return 0; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint4b_t const min() noexcept { return 0; }
|
||||
CUTE_HOST_DEVICE static constexpr
|
||||
cute::uint4b_t const max() noexcept { return 15; }
|
||||
static constexpr bool is_integer = true;
|
||||
static constexpr bool is_signed = false;
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
414
include/cute/numeric/integral_constant.hpp
Normal file
@ -0,0 +1,414 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/numeric/math.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class T, T v>
struct constant : std::integral_constant<T,v> {
  static constexpr T value = v;
  using value_type = T;
  using type = constant<T,v>;
  CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
  CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
};

template <class T, T v>
using integral_constant = constant<T,v>;
|
||||
template <bool b>
|
||||
using bool_constant = constant<bool,b>;
|
||||
|
||||
using true_type = bool_constant<true>;
|
||||
using false_type = bool_constant<false>;
|
||||
|
||||
//
|
||||
// Traits
|
||||
//
|
||||
|
||||
// Use std::is_integral<T> to match built-in integral types (int, int64_t, unsigned, etc)
|
||||
// Use cute::is_integral<T> to match both built-in integral types AND constant<T,t>
|
||||
|
||||
template <class T>
|
||||
struct is_integral : bool_constant<std::is_integral<T>::value> {};
|
||||
template <class T, T v>
|
||||
struct is_integral<constant<T,v>> : true_type {};
|
||||
|
||||
// is_static detects if an (abstract) value is defined completely by it's type (no members)
|
||||
|
||||
template <class T>
|
||||
struct is_static : bool_constant<std::is_empty<T>::value> {};
|
||||
|
||||
// is_constant detects if a type is a constant<T,v> and if v is equal to a value
|
||||
|
||||
template <auto n, class T>
|
||||
struct is_constant : false_type {};
|
||||
template <auto n, class T, T v>
|
||||
struct is_constant<n, constant<T,v> > : bool_constant<v == n> {};
|
||||
template <auto n, class T, T v>
|
||||
struct is_constant<n, constant<T,v> const > : bool_constant<v == n> {};
|
||||
template <auto n, class T, T v>
|
||||
struct is_constant<n, constant<T,v> const&> : bool_constant<v == n> {};
|
||||
template <auto n, class T, T v>
|
||||
struct is_constant<n, constant<T,v> &> : bool_constant<v == n> {};
|
||||
template <auto n, class T, T v>
|
||||
struct is_constant<n, constant<T,v> &&> : bool_constant<v == n> {};
|
||||
|
||||
//
|
||||
// Specializations
|
||||
//
|
||||
|
||||
template <int v>
|
||||
using Int = constant<int,v>;
|
||||
|
||||
using _m32 = Int<-32>;
|
||||
using _m24 = Int<-24>;
|
||||
using _m16 = Int<-16>;
|
||||
using _m12 = Int<-12>;
|
||||
using _m10 = Int<-10>;
|
||||
using _m9 = Int<-9>;
|
||||
using _m8 = Int<-8>;
|
||||
using _m7 = Int<-7>;
|
||||
using _m6 = Int<-6>;
|
||||
using _m5 = Int<-5>;
|
||||
using _m4 = Int<-4>;
|
||||
using _m3 = Int<-3>;
|
||||
using _m2 = Int<-2>;
|
||||
using _m1 = Int<-1>;
|
||||
using _0 = Int<0>;
|
||||
using _1 = Int<1>;
|
||||
using _2 = Int<2>;
|
||||
using _3 = Int<3>;
|
||||
using _4 = Int<4>;
|
||||
using _5 = Int<5>;
|
||||
using _6 = Int<6>;
|
||||
using _7 = Int<7>;
|
||||
using _8 = Int<8>;
|
||||
using _9 = Int<9>;
|
||||
using _10 = Int<10>;
|
||||
using _12 = Int<12>;
|
||||
using _16 = Int<16>;
|
||||
using _24 = Int<24>;
|
||||
using _32 = Int<32>;
|
||||
using _64 = Int<64>;
|
||||
using _96 = Int<96>;
|
||||
using _128 = Int<128>;
|
||||
using _192 = Int<192>;
|
||||
using _256 = Int<256>;
|
||||
using _512 = Int<512>;
|
||||
using _1024 = Int<1024>;
|
||||
using _2048 = Int<2048>;
|
||||
using _4096 = Int<4096>;
|
||||
using _8192 = Int<8192>;
|
||||
|
||||
/***************/
|
||||
/** Operators **/
|
||||
/***************/
|
||||
|
||||
#define CUTE_LEFT_UNARY_OP(OP) \
|
||||
template <class T, T t> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
constant<decltype(OP t), (OP t)> \
|
||||
operator OP (constant<T,t>) { \
|
||||
return {}; \
|
||||
}
|
||||
#define CUTE_RIGHT_UNARY_OP(OP) \
|
||||
template <class T, T t> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
constant<decltype(t OP), (t OP)> \
|
||||
operator OP (constant<T,t>) { \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
#define CUTE_BINARY_OP(OP) \
|
||||
template <class T, T t, class U, U u> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
constant<decltype(t OP u), (t OP u)> \
|
||||
operator OP (constant<T,t>, constant<U,u>) { \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
CUTE_LEFT_UNARY_OP(+);
|
||||
CUTE_LEFT_UNARY_OP(-);
|
||||
CUTE_LEFT_UNARY_OP(~);
|
||||
CUTE_LEFT_UNARY_OP(!);
|
||||
CUTE_LEFT_UNARY_OP(*);
|
||||
|
||||
CUTE_BINARY_OP( +);
|
||||
CUTE_BINARY_OP( -);
|
||||
CUTE_BINARY_OP( *);
|
||||
CUTE_BINARY_OP( /);
|
||||
CUTE_BINARY_OP( %);
|
||||
CUTE_BINARY_OP( &);
|
||||
CUTE_BINARY_OP( |);
|
||||
CUTE_BINARY_OP( ^);
|
||||
CUTE_BINARY_OP(<<);
|
||||
CUTE_BINARY_OP(>>);
|
||||
|
||||
CUTE_BINARY_OP(&&);
|
||||
CUTE_BINARY_OP(||);
|
||||
|
||||
CUTE_BINARY_OP(==);
|
||||
CUTE_BINARY_OP(!=);
|
||||
CUTE_BINARY_OP( >);
|
||||
CUTE_BINARY_OP( <);
|
||||
CUTE_BINARY_OP(>=);
|
||||
CUTE_BINARY_OP(<=);
|
||||
|
||||
#undef CUTE_BINARY_OP
|
||||
#undef CUTE_LEFT_UNARY_OP
|
||||
#undef CUTE_RIGHT_UNARY_OP
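Because every operator returns another constant<>, arithmetic on Int<> values never leaves the type system; the results are empty types that can be used as static shapes and strides. A compile-time sketch:

```cpp
#include <cute/numeric/integral_constant.hpp>

using namespace cute;

constexpr auto a = _4{} * _8{} + _2{};                  // constant<int,34>
static_assert(decltype(a)::value == 34, "");
static_assert(is_static<decltype(a)>::value,   "carries no runtime state");
static_assert(is_constant<34, decltype(a)>::value, "");
```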
|
||||
|
||||
//
|
||||
// Mixed static-dynamic special cases
|
||||
//
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator*(constant<T, 0>, U) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class U, class T,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator*(U, constant<T, 0>) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator/(constant<T, 0>, U) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class U, class T,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator%(U, constant<T, 1>) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class U, class T,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator%(U, constant<T,-1>) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator%(constant<T, 0>, U) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator&(constant<T, 0>, U) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<T, 0>
|
||||
operator&(U, constant<T, 0>) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, T t, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, false>
|
||||
operator&&(constant<T, t>, U) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, T t, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && !bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, false>
|
||||
operator&&(U, constant<T, t>) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U, T t,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, true>
|
||||
operator||(constant<T, t>, U) {
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, class U, T t,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value && bool(t))>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<bool, true>
|
||||
operator||(U, constant<T, t>) {
|
||||
return {};
|
||||
}
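These mixed overloads let a compile-time zero (or one) absorb a runtime operand, so the result stays static even though one side is dynamic. For example:

```cpp
#include <cute/numeric/integral_constant.hpp>

int main(int argc, char**) {
  auto z = cute::_0{} * argc;    // constant<int,0> regardless of argc
  auto r = argc % cute::_1{};    // constant<int,0>: x % 1 is always 0
  static_assert(cute::is_constant<0, decltype(z)>::value, "");
  static_assert(cute::is_constant<0, decltype(r)>::value, "");
  return z + r;                  // both convert to int 0
}
```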
|
||||
|
||||
//
|
||||
// Named functions from math.hpp
|
||||
//
|
||||
|
||||
#define CUTE_NAMED_UNARY_FN(OP) \
|
||||
template <class T, T t> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
constant<decltype(OP(t)), OP(t)> \
|
||||
OP (constant<T,t>) { \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
#define CUTE_NAMED_BINARY_FN(OP) \
|
||||
template <class T, T t, class U, U u> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
constant<decltype(OP(t,u)), OP(t,u)> \
|
||||
OP (constant<T,t>, constant<U,u>) { \
|
||||
return {}; \
|
||||
} \
|
||||
\
|
||||
template <class T, T t, class U, \
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto \
|
||||
OP (constant<T,t>, U u) { \
|
||||
return OP(t,u); \
|
||||
} \
|
||||
\
|
||||
template <class T, class U, U u, \
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value)> \
|
||||
CUTE_HOST_DEVICE constexpr \
|
||||
auto \
|
||||
OP (T t, constant<U,u>) { \
|
||||
return OP(t,u); \
|
||||
}
|
||||
|
||||
CUTE_NAMED_UNARY_FN(abs);
|
||||
CUTE_NAMED_UNARY_FN(signum);
|
||||
CUTE_NAMED_UNARY_FN(has_single_bit);
|
||||
|
||||
CUTE_NAMED_BINARY_FN(max);
|
||||
CUTE_NAMED_BINARY_FN(min);
|
||||
CUTE_NAMED_BINARY_FN(shiftl);
|
||||
CUTE_NAMED_BINARY_FN(shiftr);
|
||||
CUTE_NAMED_BINARY_FN(gcd);
|
||||
CUTE_NAMED_BINARY_FN(lcm);
|
||||
|
||||
#undef CUTE_NAMED_UNARY_FN
|
||||
#undef CUTE_NAMED_BINARY_FN
|
||||
|
||||
//
|
||||
// Other functions
|
||||
//
|
||||
|
||||
template <class T, T t, class U, U u>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
constant<decltype(t / u), t / u>
|
||||
safe_div(constant<T, t>, constant<U, u>) {
|
||||
static_assert(t % u == 0, "Static safe_div requires t % u == 0");
|
||||
return {};
|
||||
}
|
||||
|
||||
template <class T, T t, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(constant<T, t>, U u) {
|
||||
return t / u;
|
||||
}
|
||||
|
||||
template <class T, class U, U u,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(T t, constant<U, u>) {
|
||||
return t / u;
|
||||
}
|
||||
|
||||
// cute::true_type prefers standard conversion to std::true_type
|
||||
// over user-defined conversion to bool
|
||||
template <class TrueType, class FalseType>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
conditional_return(std::true_type, TrueType&& t, FalseType&&) {
|
||||
return static_cast<TrueType&&>(t);
|
||||
}
|
||||
|
||||
// cute::false_type prefers standard conversion to std::false_type
|
||||
// over user-defined conversion to bool
|
||||
template <class TrueType, class FalseType>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
conditional_return(std::false_type, TrueType&&, FalseType&& f) {
|
||||
return static_cast<FalseType&&>(f);
|
||||
}
|
||||
|
||||
// TrueType and FalseType must have a common type
|
||||
template <class TrueType, class FalseType>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
conditional_return(bool b, TrueType const& t, FalseType const& f) {
|
||||
return b ? t : f;
|
||||
}
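Dispatching on cute::true_type / false_type (rather than a plain bool) selects one overload at compile time, so the two alternatives may have different types; the bool overload instead requires a common type. A sketch:

```cpp
#include <cute/numeric/integral_constant.hpp>

template <class IsStatic>
auto pick(IsStatic flag) {
  // true_type selects the first alternative (Int<8>), false_type the second (int)
  return cute::conditional_return(flag, cute::_8{}, 42);
}

int main() {
  auto a = pick(cute::true_type{});    // cute::Int<8>
  auto b = pick(cute::false_type{});   // int, value 42
  static_assert(cute::is_static<decltype(a)>::value, "");
  return int(a) - 8 + (b - 42);        // 0
}
```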
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
template <class T, T N>
|
||||
CUTE_HOST_DEVICE void print(integral_constant<T,N> const&) {
|
||||
printf("_%d", N);
|
||||
}
|
||||
|
||||
template <class T, T N>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant<T,N> const&) {
|
||||
return os << "_" << N;
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
319
include/cute/numeric/math.hpp
Normal file
@ -0,0 +1,319 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Common Operations
|
||||
//
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
|
||||
std::is_arithmetic<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
max(T const& t, U const& u) {
|
||||
return t < u ? u : t;
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_arithmetic<T>::value &&
|
||||
std::is_arithmetic<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
min(T const& t, U const& u) {
|
||||
return t < u ? t : u;
|
||||
}
|
||||
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(std::is_arithmetic<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
abs(T const& t) {
|
||||
if constexpr (std::is_signed<T>::value) {
|
||||
return t < T(0) ? -t : t;
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// C++17 <numeric> operations
|
||||
//
|
||||
|
||||
// Greatest common divisor of two integers
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value &&
|
||||
std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
gcd(T t, U u) {
|
||||
while (true) {
|
||||
if (t == 0) { return u; }
|
||||
u %= t;
|
||||
if (u == 0) { return t; }
|
||||
t %= u;
|
||||
}
|
||||
}
|
||||
|
||||
// Least common multiple of two integers
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value &&
|
||||
std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
lcm(T const& t, U const& u) {
|
||||
return (t / gcd(t,u)) * u;
|
||||
}
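gcd is the plain Euclidean algorithm and lcm is built on top of it; both are usable in constant expressions. For instance:

```cpp
#include <cute/numeric/math.hpp>

static_assert(cute::gcd(12, 18) == 6,  "");
static_assert(cute::lcm(4, 6)   == 12, "");
```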
|
||||
|
||||
//
|
||||
// C++20 <bit> operations
|
||||
//
|
||||
|
||||
// Checks if a number is an integral power of two
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool
|
||||
has_single_bit(T x) {
|
||||
return x != 0 && (x & (x - 1)) == 0;
|
||||
}
|
||||
|
||||
// Smallest number of bits needed to represent the given value
|
||||
// bit_width( 0b0000 ) = 0
|
||||
// bit_width( 0b0001 ) = 1
|
||||
// bit_width( 0b0010 ) = 2
|
||||
// bit_width( 0b0011 ) = 2
|
||||
// bit_width( 0b0100 ) = 3
|
||||
// bit_width( 0b0101 ) = 3
|
||||
// bit_width( 0b0110 ) = 3
|
||||
// bit_width( 0b0111 ) = 3
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
bit_width(T x) {
|
||||
static_assert(std::is_unsigned<T>::value, "Only to be used for unsigned types.");
|
||||
constexpr int N = (std::numeric_limits<T>::digits == 64 ? 6 :
|
||||
(std::numeric_limits<T>::digits == 32 ? 5 :
|
||||
(std::numeric_limits<T>::digits == 16 ? 4 :
|
||||
(std::numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
|
||||
T r = 0;
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
T shift = (x > ((T(1) << (T(1) << i))-1)) << i;
|
||||
x >>= shift;
|
||||
r |= shift;
|
||||
}
|
||||
return r + (x != 0);
|
||||
}
|
||||
|
||||
// Smallest integral power of two not less than the given value
|
||||
// bit_ceil( 0b00000000 ) = 0b00000001
|
||||
// bit_ceil( 0b00000001 ) = 0b00000001
|
||||
// bit_ceil( 0b00000010 ) = 0b00000010
|
||||
// bit_ceil( 0b00000011 ) = 0b00000100
|
||||
// bit_ceil( 0b00000100 ) = 0b00000100
|
||||
// bit_ceil( 0b00000101 ) = 0b00001000
|
||||
// bit_ceil( 0b00000110 ) = 0b00001000
|
||||
// bit_ceil( 0b00000111 ) = 0b00001000
|
||||
// bit_ceil( 0b00001000 ) = 0b00001000
|
||||
// bit_ceil( 0b00001001 ) = 0b00010000
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
bit_ceil(T x) {
|
||||
return x == 0 ? T(1) : (T(1) << bit_width(x - 1));
|
||||
}
|
||||
|
||||
// Largest integral power of two not greater than the given value
|
||||
// bit_floor( 0b00000000 ) = 0b00000000
|
||||
// bit_floor( 0b00000001 ) = 0b00000001
|
||||
// bit_floor( 0b00000010 ) = 0b00000010
|
||||
// bit_floor( 0b00000011 ) = 0b00000010
|
||||
// bit_floor( 0b00000100 ) = 0b00000100
|
||||
// bit_floor( 0b00000101 ) = 0b00000100
|
||||
// bit_floor( 0b00000110 ) = 0b00000100
|
||||
// bit_floor( 0b00000111 ) = 0b00000100
|
||||
// bit_floor( 0b00001000 ) = 0b00001000
|
||||
// bit_floor( 0b00001001 ) = 0b00001000
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
bit_floor(T x) {
|
||||
return x == 0 ? 0 : (T(1) << (bit_width(x) - 1));
|
||||
}
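A few spot checks matching the tables in the comments above (host-side, using the header path from this diff):

```cpp
#include <cstdint>
#include <cute/numeric/math.hpp>

static_assert(cute::bit_width(std::uint32_t(0b0101)) == 3, "");
static_assert(cute::bit_ceil (std::uint32_t(5))      == 8, "");
static_assert(cute::bit_floor(std::uint32_t(5))      == 4, "");
static_assert(cute::has_single_bit(std::uint32_t(64)),     "");
```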
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr T rotl(T x, int s);
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr T rotr(T x, int s);
|
||||
|
||||
// Computes the result of circular bitwise left-rotation
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
rotl(T x, int s) {
|
||||
constexpr int N = std::numeric_limits<T>::digits;
|
||||
return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s);
|
||||
}
|
||||
|
||||
// Computes the result of circular bitwise right-rotation
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
rotr(T x, int s) {
|
||||
constexpr int N = std::numeric_limits<T>::digits;
|
||||
return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s);
|
||||
}
|
||||
|
||||
// Counts the number of consecutive 0 bits, starting from the most significant bit
|
||||
// countl_zero( 0b00000000 ) = 8
|
||||
// countl_zero( 0b11111111 ) = 0
|
||||
// countl_zero( 0b00011100 ) = 3
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
countl_zero(T x) {
|
||||
return std::numeric_limits<T>::digits - bit_width(x);
|
||||
}
|
||||
|
||||
// Counts the number of consecutive 1 bits, starting from the most significant bit
|
||||
// countl_one( 0b00000000 ) = 0
|
||||
// countl_one( 0b11111111 ) = 8
|
||||
// countl_one( 0b11100011 ) = 3
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
countl_one(T x) {
|
||||
return countl_zero(~x);
|
||||
}
|
||||
|
||||
// Counts the number of consecutive 0 bits, starting from the least significant bit
|
||||
// countr_zero( 0b00000000 ) = 8
|
||||
// countr_zero( 0b11111111 ) = 0
|
||||
// countr_zero( 0b00011100 ) = 2
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
countr_zero(T x) {
|
||||
return x == 0 ? std::numeric_limits<T>::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB
|
||||
}
|
||||
|
||||
// Counts the number of consecutive 1 bits, starting from the least significant bit
|
||||
// countr_one( 0b00000000 ) = 0
|
||||
// countr_one( 0b11111111 ) = 8
|
||||
// countr_one( 0b11100011 ) = 2
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
countr_one(T x) {
|
||||
return countr_zero(~x);
|
||||
}
|
||||
|
||||
// Counts the number of 1 bits in an unsigned integer
|
||||
// popcount( 0b00000000 ) = 0
|
||||
// popcount( 0b11111111 ) = 8
|
||||
// popcount( 0b00011101 ) = 4
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
int
|
||||
popcount(T x) {
|
||||
int c = 0;
|
||||
while (x) {
|
||||
++c;
|
||||
x &= x - 1; // clear the least significant bit set
|
||||
}
|
||||
return c;
|
||||
}
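The count functions are built on bit_width and a Kernighan-style loop rather than hardware intrinsics, so they also work in constant expressions. Spot checks against the comment tables:

```cpp
#include <cstdint>
#include <cute/numeric/math.hpp>

static_assert(cute::countl_zero(std::uint8_t(0b00011100)) == 3, "");
static_assert(cute::countr_zero(std::uint8_t(0b00011100)) == 2, "");
static_assert(cute::popcount   (std::uint8_t(0b00011101)) == 4, "");
```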
|
||||
|
||||
//
|
||||
// Custom operations
|
||||
//
|
||||
|
||||
// Computes the result of bitwise left-shift
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
shiftl(T x, int s) {
|
||||
return s >= 0 ? (x << s) : (x >> -s);
|
||||
}
|
||||
|
||||
// Computes the result of bitwise right-shift
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T
|
||||
shiftr(T x, int s) {
|
||||
return s >= 0 ? (x >> s) : (x << -s);
|
||||
}
|
||||
|
||||
// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero.
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(std::is_unsigned<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
int
|
||||
signum(T const& x) {
|
||||
return T(0) < x;
|
||||
}
|
||||
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(not std::is_unsigned<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
int
|
||||
signum(T const& x) {
|
||||
return (T(0) < x) - (x < T(0));
|
||||
}
|
||||
|
||||
// Safe divide
|
||||
// @pre t % u == 0
|
||||
// @result t / u
|
||||
template <class T, class U,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value &&
|
||||
std::is_integral<U>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(T const& t, U const& u) {
|
||||
//assert(t % u == 0);
|
||||
return t / u;
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
56
include/cute/numeric/real.hpp
Normal file
@ -0,0 +1,56 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/// Generic fused multiply-add
|
||||
template <class D, class A, class B, class C>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
fma(D& d, A const& a, B const& b, C const& c)
|
||||
{
|
||||
d = a * b + c;
|
||||
}
|
||||
|
||||
/// Fused multiply-add for triplets
|
||||
template <class A, class B, class C>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
void
|
||||
fma(A const& a, B const& b, C& c)
|
||||
{
|
||||
return fma(c, a, b, c);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
51
include/cute/numeric/tfloat.hpp
Normal file
@ -0,0 +1,51 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <vector_types.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace cute {
|
||||
|
||||
using cutlass::tfloat32_t;
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v)
|
||||
{
|
||||
return os << float(v);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
259
include/cute/numeric/uint128.hpp
Normal file
@ -0,0 +1,259 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include <type_traits>
|
||||
#include <stdexcept>
|
||||
#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
/// Optionally enable GCC's built-in type
|
||||
#if defined(__x86_64) && !defined(__CUDA_ARCH__)
|
||||
# if defined(__GNUC__) && 0
|
||||
# define CUTE_UINT128_NATIVE
|
||||
# elif defined(_MSC_VER)
|
||||
# define CUTE_INT128_ARITHMETIC
|
||||
# include <intrin.h>
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///! Unsigned 128b integer type
|
||||
struct alignas(16) uint128_t
|
||||
{
|
||||
/// Size of one part of the uint's storage in bits
|
||||
static constexpr int storage_bits_ = 64;
|
||||
|
||||
struct hilo
|
||||
{
|
||||
uint64_t lo;
|
||||
uint64_t hi;
|
||||
};
|
||||
|
||||
// Use a union to store either low and high parts or, if present, a built-in 128b integer type.
|
||||
union
|
||||
{
|
||||
struct hilo hilo_;
|
||||
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
unsigned __int128 native;
|
||||
#endif // defined(CUTE_UINT128_NATIVE)
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t() : hilo_{0, 0} {}
|
||||
|
||||
/// Constructor from uint64
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t(uint64_t lo_) : hilo_{lo_, 0} {}
|
||||
|
||||
/// Constructor from two 64b unsigned integers
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {}
|
||||
|
||||
/// Optional constructor from native value
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
uint128_t(unsigned __int128 value) : native(value) { }
|
||||
#endif
|
||||
|
||||
/// Lossily cast to uint64
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
explicit operator uint64_t() const
|
||||
{
|
||||
return hilo_.lo;
|
||||
}
|
||||
|
||||
template <class Dummy = bool>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
static void exception()
|
||||
{
|
||||
//static_assert(sizeof(Dummy) == 0, "Not implemented exception!");
|
||||
//abort();
|
||||
//printf("uint128 not implemented!\n");
|
||||
}
|
||||
|
||||
/// Add
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t operator+(uint128_t const& rhs) const
|
||||
{
|
||||
uint128_t y;
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
y.native = native + rhs.native;
|
||||
#else
|
||||
y.hilo_.lo = hilo_.lo + rhs.hilo_.lo;
|
||||
y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo));
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
/// Subtract
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t operator-(uint128_t const& rhs) const
|
||||
{
|
||||
uint128_t y;
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
y.native = native - rhs.native;
|
||||
#else
|
||||
y.hilo_.lo = hilo_.lo - rhs.hilo_.lo;
|
||||
y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo);
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
/// Multiply by unsigned 64b integer yielding 128b integer
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t operator*(uint64_t const& rhs) const
|
||||
{
|
||||
uint128_t y;
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
y.native = native * rhs;
|
||||
#elif defined(CUTE_INT128_ARITHMETIC)
|
||||
// Multiply by the low part
|
||||
y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi);
|
||||
|
||||
// Add the high part and ignore the overflow
|
||||
uint64_t overflow;
|
||||
y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow);
|
||||
#else
|
||||
exception();
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
/// Divide a 128b operand by a 64b operand yielding a 64b quotient
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint64_t operator/(uint64_t const& divisor) const
|
||||
{
|
||||
uint64_t quotient = 0;
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
quotient = uint64_t(native / divisor);
|
||||
#elif defined(CUTE_INT128_ARITHMETIC)
|
||||
// implemented using MSVC's arithmetic intrinsics
|
||||
uint64_t remainder = 0;
|
||||
quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
|
||||
#else
|
||||
exception();
|
||||
#endif
|
||||
return quotient;
|
||||
}
|
||||
|
||||
/// Divide a 128b operand by a 64b operand yielding a 64b remainder
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint64_t operator%(uint64_t const& divisor) const
|
||||
{
|
||||
uint64_t remainder = 0;
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
remainder = uint64_t(native % divisor);
|
||||
#elif defined(CUTE_INT128_ARITHMETIC)
|
||||
// implemented using MSVC's arithmetic intrinsics
|
||||
(void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
|
||||
#else
|
||||
exception();
|
||||
#endif
|
||||
return remainder;
|
||||
}
|
||||
|
||||
/// Computes the quotient and remainder in a single method.
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint64_t divmod(uint64_t &remainder, uint64_t divisor) const
|
||||
{
|
||||
uint64_t quotient = 0;
|
||||
#if defined(CUTE_UINT128_NATIVE)
|
||||
quotient = uint64_t(native / divisor);
|
||||
remainder = uint64_t(native % divisor);
|
||||
#elif defined(CUTE_INT128_ARITHMETIC)
|
||||
// implemented using MSVC's arithmetic intrinsics
|
||||
quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
|
||||
#else
|
||||
exception();
|
||||
#endif
|
||||
return quotient;
|
||||
}
|
||||
|
||||
/// Left-shifts a 128b unsigned integer
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t operator<<(int sh) const
|
||||
{
|
||||
if (sh == 0) {
|
||||
return *this;
|
||||
}
|
||||
else if (sh >= storage_bits_) {
|
||||
return uint128_t(0, hilo_.lo << (sh - storage_bits_));
|
||||
}
|
||||
else {
|
||||
return uint128_t(
|
||||
(hilo_.lo << sh),
|
||||
(hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Right-shifts a 128b unsigned integer
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint128_t operator>>(int sh) const
|
||||
{
|
||||
if (sh == 0) {
|
||||
return *this;
|
||||
}
|
||||
else if (sh >= storage_bits_) {
|
||||
return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0);
|
||||
}
|
||||
else {
|
||||
return uint128_t(
|
||||
(hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)),
|
||||
(hilo_.hi >> sh)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cute
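Addition and subtraction always go through the portable (lo, hi) fallback; multiplication, division, and modulus need either the native __int128 path (which this revision leaves disabled via `#if defined(__GNUC__) && 0`) or MSVC's _umul128/_udiv128 intrinsics, and otherwise fall through to the no-op exception(). A small host sketch of the portable part:

```cpp
#include <cstdio>
#include <cute/numeric/uint128.hpp>

int main() {
  cute::uint128_t x(0xFFFFFFFFFFFFFFFFull);     // lo = 2^64 - 1, hi = 0
  cute::uint128_t y = x + cute::uint128_t(1);   // carries into the high word
  cute::uint128_t z = y - cute::uint128_t(1);   // and borrows back again

  std::printf("y = (hi=%llu, lo=%llu)\n",
              (unsigned long long)y.hilo_.hi, (unsigned long long)y.hilo_.lo);
  std::printf("z.lo = %llx\n", (unsigned long long)z.hilo_.lo);  // ffffffffffffffff
  return 0;
}
```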
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
322
include/cute/pointer.hpp
Normal file
@ -0,0 +1,322 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/numeric/math.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// has_dereference to determine if a type is a pointer concept
|
||||
//
|
||||
|
||||
template <class T, class = void>
|
||||
struct has_dereference : std::false_type {
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct has_dereference<T, void_t<decltype(*std::declval<T>())>> : std::true_type {
|
||||
};
|
||||
|
||||
//
|
||||
// Pointer categories
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct is_gmem : false_type {};
|
||||
|
||||
template <class T>
|
||||
struct is_smem : false_type {};
|
||||
|
||||
// Anything that is not gmem or smem is rmem
|
||||
template <class T>
|
||||
struct is_rmem : bool_constant< not (is_gmem<T>::value || is_smem<T>::value)> {};
|
||||
|
||||
//
|
||||
// A very simplified wrapper for pointers -- use for constructing tagged pointers
|
||||
//
|
||||
template <class T, class DerivedType>
|
||||
struct device_ptr
|
||||
{
|
||||
using value_type = T;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
device_ptr(T* ptr) : ptr_(ptr) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T* get() const { return ptr_; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& operator*() const { return *ptr_; }
|
||||
|
||||
template <class Index>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& operator[](Index const& i) const { return ptr_[i]; }
|
||||
|
||||
template <class Index>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
DerivedType operator+(Index const& i) const { return {ptr_ + i}; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr friend
|
||||
std::ptrdiff_t operator-(device_ptr<T,DerivedType> const& a,
|
||||
device_ptr<T,DerivedType> const& b) {
|
||||
return a.ptr_ - b.ptr_;
|
||||
}
|
||||
|
||||
T* ptr_;
|
||||
};
|
||||
|
||||
//
|
||||
// gmem_ptr
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct gmem_ptr : device_ptr<T, gmem_ptr<T>> {
|
||||
using device_ptr<T, gmem_ptr<T>>::device_ptr;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
gmem_ptr<T>
|
||||
make_gmem_ptr(T* ptr) {
|
||||
return {ptr};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
gmem_ptr<T>
|
||||
make_gmem_ptr(void* ptr) {
|
||||
return {reinterpret_cast<T*>(ptr)};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct is_gmem<gmem_ptr<T>> : true_type {};
|
||||
|
||||
//
|
||||
// smem_ptr
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct smem_ptr : device_ptr<T, smem_ptr<T>> {
|
||||
using device_ptr<T, smem_ptr<T>>::device_ptr;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
smem_ptr<T>
|
||||
make_smem_ptr(T* ptr) {
|
||||
return {ptr};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
smem_ptr<T>
|
||||
make_smem_ptr(void* ptr) {
|
||||
return {reinterpret_cast<T*>(ptr)};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct is_smem<smem_ptr<T>> : true_type {};
|
||||
|
||||
//
|
||||
// rmem_ptr
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct rmem_ptr : device_ptr<T, rmem_ptr<T>> {
|
||||
using device_ptr<T, rmem_ptr<T>>::device_ptr;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
rmem_ptr<T>
|
||||
make_rmem_ptr(T* ptr) {
|
||||
return {ptr};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
rmem_ptr<T>
|
||||
make_rmem_ptr(void* ptr) {
|
||||
return {reinterpret_cast<T*>(ptr)};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct is_rmem<rmem_ptr<T>> : true_type {};
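The gmem/smem/rmem wrappers only tag a raw pointer with its intended memory space; nothing checks the actual address. The tags exist so that downstream code can dispatch on them, as in this sketch:

```cpp
#include <cute/pointer.hpp>

template <class Ptr>
constexpr char const* where(Ptr) {
  if constexpr (cute::is_gmem<Ptr>::value)      { return "gmem"; }
  else if constexpr (cute::is_smem<Ptr>::value) { return "smem"; }
  else                                          { return "rmem"; }
}

int main() {
  float buf[4] = {};
  auto g = cute::make_gmem_ptr(buf);   // gmem_ptr<float>
  auto s = cute::make_smem_ptr(buf);   // smem_ptr<float>
  static_assert(cute::is_rmem<float*>::value, "untagged pointers count as rmem");
  return (where(g)[0] == 'g' && where(s)[0] == 's') ? 0 : 1;
}
```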
|
||||
|
||||
//
|
||||
// counting iterator -- quick and dirty
|
||||
//
|
||||
|
||||
struct counting
|
||||
{
|
||||
using index_type = int;
|
||||
using value_type = index_type;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
counting() : n_(0) {}
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
counting(index_type const& n) : n_(n) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
index_type operator[](index_type const& i) const { return n_ + i; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
index_type const& operator*() const { return n_; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
counting operator+(index_type const& i) const { return {n_ + i}; }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
counting& operator++() { ++n_; return *this; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator==(counting const& other) const { return n_ == other.n_; }
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator!=(counting const& other) const { return n_ != other.n_; }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
bool operator< (counting const& other) const { return n_ < other.n_; }
|
||||
|
||||
index_type n_;
|
||||
};
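counting behaves like an iterator over the integers starting at its constructor argument; dereference and subscript simply add the offset. For example:

```cpp
#include <cute/pointer.hpp>

int main() {
  cute::counting it(5);
  int sum = 0;
  for (cute::counting i = it; i != it + 3; ++i) {
    sum += *i;                       // visits 5, 6, 7
  }
  return sum == 18 ? 0 : 1;
}
```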
|
||||
|
||||
//
|
||||
// recast
|
||||
//
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(T* ptr) {
|
||||
return reinterpret_cast<NewT*>(ptr);
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(T const* ptr) {
|
||||
return reinterpret_cast<NewT const*>(ptr);
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(gmem_ptr<T> const& ptr) {
|
||||
return make_gmem_ptr(recast<NewT>(ptr.ptr_));
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(gmem_ptr<T const> const& ptr) {
|
||||
return make_gmem_ptr(recast<NewT const>(ptr.ptr_));
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(smem_ptr<T> const& ptr) {
|
||||
return make_smem_ptr(recast<NewT>(ptr.ptr_));
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(smem_ptr<T const> const& ptr) {
|
||||
return make_smem_ptr(recast<NewT const>(ptr.ptr_));
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(rmem_ptr<T> const& ptr) {
|
||||
return make_rmem_ptr(recast<NewT>(ptr.ptr_));
|
||||
}
|
||||
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(rmem_ptr<T const> const& ptr) {
|
||||
return make_rmem_ptr(recast<NewT const>(ptr.ptr_));
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(T const* const ptr)
|
||||
{
|
||||
printf("raw_ptr_%db(%p)", int(8*sizeof(T)), ptr);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(gmem_ptr<T> const& ptr)
|
||||
{
|
||||
printf("gmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(smem_ptr<T> const& ptr)
|
||||
{
|
||||
printf("smem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(rmem_ptr<T> const& ptr)
|
||||
{
|
||||
printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> const& ptr)
|
||||
{
|
||||
return os << "gmem_ptr_" << int(8*sizeof(T)) << "b";
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr<T> const& ptr)
|
||||
{
|
||||
return os << "smem_ptr_" << int(8*sizeof(T)) << "b";
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> const& ptr)
|
||||
{
|
||||
return os << "rmem_ptr_" << int(8*sizeof(T)) << "b";
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
411
include/cute/stride.hpp (new file)
@ -0,0 +1,411 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/int_tuple.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/** crd2idx maps a coordinate within <Shape,Stride> to an index
|
||||
* This is computed as follows:
|
||||
* [coord, shape, and stride are all integers => step forward by stride]
|
||||
* op(c, s, d) => c * d
|
||||
* [coord is integer, shape and stride are tuple => divmod coord for each mode]
|
||||
* op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D))
|
||||
* [coord, shape, and stride are all tuples => consider each mode independently]
|
||||
* op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D))
|
||||
*/
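// Illustrative worked examples only (values follow the op rules above):
//   crd2idx((1,2), (3,4), (1,3))  =>  1*1 + 2*3         = 7
//   crd2idx(  5  , (3,4), (1,3))  =>  (5%3)*1 + (5/3)*3 = 5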
|
||||
|
||||
template <class Coord, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2idx(Coord const& coord,
|
||||
Shape const& shape,
|
||||
Stride const& stride);
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Coord, class Shape, class Stride, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2idx_ttt(Coord const& coord,
|
||||
Shape const& shape,
|
||||
Stride const& stride, seq<Is...>)
|
||||
{
|
||||
return (... + crd2idx(get<Is>(coord), get<Is>(shape), get<Is>(stride)));
|
||||
}
|
||||
|
||||
template <class CInt, class STuple, class DTuple, int I0, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2idx_itt(CInt const& coord,
|
||||
STuple const& shape,
|
||||
DTuple const& stride, seq<I0,Is...>)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter
|
||||
return crd2idx(coord, get<I0>(shape), get<I0>(stride));
|
||||
} else { // General case
|
||||
return crd2idx(coord % product(get<I0>(shape)), get<I0>(shape), get<I0>(stride))
|
||||
+ crd2idx_itt(coord / product(get<I0>(shape)), shape, stride, seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Coord, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2idx(Coord const& coord,
|
||||
Shape const& shape,
|
||||
Stride const& stride)
|
||||
{
|
||||
if constexpr (is_tuple<Coord>::value) {
|
||||
if constexpr (is_tuple<Shape>::value) { // tuple tuple tuple
|
||||
static_assert(tuple_size<Coord>::value == tuple_size< Shape>::value, "Mismatched Ranks");
|
||||
static_assert(tuple_size<Coord>::value == tuple_size<Stride>::value, "Mismatched Ranks");
|
||||
return detail::crd2idx_ttt(coord, shape, stride, tuple_seq<Coord>{});
|
||||
} else { // tuple "int" "int"
|
||||
static_assert(sizeof(Coord) == 0, "Invalid parameters");
|
||||
}
|
||||
} else {
|
||||
if constexpr (is_tuple<Shape>::value) { // "int" tuple tuple
|
||||
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
|
||||
return detail::crd2idx_itt(coord, shape, stride, tuple_seq<Shape>{});
|
||||
} else { // "int" "int" "int"
|
||||
return coord * stride;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// If we know Stride is default [CompactColMajor], then we can take shortcuts
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class CTuple, class STuple, int I0, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2idx_horner(CTuple const& coord,
|
||||
STuple const& shape, seq<I0,Is...>)
|
||||
{
|
||||
if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter
|
||||
return get<I0>(coord);
|
||||
} else { // General case
|
||||
return get<I0>(coord) + get<I0>(shape) * crd2idx_horner(coord, shape, seq<Is...>{});
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Coord, class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2idx(Coord const& coord,
|
||||
Shape const& shape)
|
||||
{
|
||||
static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks");
|
||||
if constexpr (is_tuple<Shape>::value) {
|
||||
// Flatten and apply Horner's method
|
||||
auto flat_coord = flatten(coord);
|
||||
auto flat_shape = flatten(shape);
|
||||
return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
|
||||
} else {
|
||||
return coord;
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
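// Illustrative worked example only (default compact column-major strides, via Horner's method):
//   crd2idx((1,2,1), (3,4,5))  =>  1 + 3*(2 + 4*1) = 19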
|
||||
|
||||
/** idx2crd splits an index to a coordinate within <Shape,Stride>.
|
||||
*
|
||||
* This is computed as follows:
|
||||
* [index, shape, and stride are all integers => determine 1D coord]
|
||||
* op(i, s, d) => (i / d) % s
|
||||
* [index is integer, shape and stride are tuple => determine component for each mode]
|
||||
* op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...)
|
||||
* [index, shape, and stride are all tuples => consider each mode independently]
|
||||
* op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D)))
|
||||
*
|
||||
* NOTE: This only works for compact shape+stride layouts. A more general version would
|
||||
 * apply to all surjective layouts.
|
||||
*/
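// Illustrative worked example only (inverse of the crd2idx example above):
//   idx2crd(7, (3,4), (1,3))  =>  ((7/1)%3, (7/3)%4) = (1,2)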
|
||||
|
||||
template <class Index, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
idx2crd(Index const& idx,
|
||||
Shape const& shape,
|
||||
Stride const& stride)
|
||||
{
|
||||
if constexpr (is_tuple<Index>::value) {
|
||||
if constexpr (is_tuple<Shape>::value) { // tuple tuple tuple
|
||||
static_assert(tuple_size<Index>::value == tuple_size< Shape>::value, "Mismatched Ranks");
|
||||
static_assert(tuple_size<Index>::value == tuple_size<Stride>::value, "Mismatched Ranks");
|
||||
return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); });
|
||||
} else { // tuple "int" "int"
|
||||
static_assert(sizeof(Index) == 0, "Invalid parameters");
|
||||
}
|
||||
} else {
|
||||
if constexpr (is_tuple<Shape>::value) {
|
||||
if constexpr (is_tuple<Stride>::value) { // "int" tuple tuple
|
||||
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
|
||||
return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); });
|
||||
} else { // "int" tuple "int"
|
||||
return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); });
|
||||
}
|
||||
} else { // "int" "int" "int"
|
||||
return (idx / stride) % shape;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// If we know Stride is default [CompactColMajor], then we can take shortcuts
|
||||
//
|
||||
|
||||
//(idx / 1) % s0
|
||||
//(idx / s0) % s1
|
||||
//(idx / (s0 * s1)) % s2
|
||||
//...
|
||||
|
||||
template <class Index, class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
idx2crd(Index const& idx,
|
||||
Shape const& shape)
|
||||
{
|
||||
if constexpr (is_tuple<Index>::value) {
|
||||
if constexpr (is_tuple<Shape>::value) { // tuple tuple
|
||||
static_assert(tuple_size<Index>::value == tuple_size<Shape>::value, "Mismatched Ranks");
|
||||
return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); });
|
||||
} else { // tuple "int"
|
||||
static_assert(sizeof(Index) == 0, "Invalid parameters");
|
||||
}
|
||||
} else {
|
||||
if constexpr (is_tuple<Shape>::value) { // "int" tuple
|
||||
return idx2crd(idx, shape, compact_col_major(shape));
|
||||
} else { // "int" "int"
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// crd2crd
|
||||
//
|
||||
|
||||
template <class Coord, class SShape, class DShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
crd2crd(Coord const& coord,
|
||||
SShape const& src_shape,
|
||||
DShape const& dst_shape)
|
||||
{
|
||||
if constexpr (is_tuple<Coord>::value && is_tuple<SShape>::value && is_tuple<DShape>::value) {
|
||||
static_assert(tuple_size<Coord>::value == tuple_size<SShape>::value, "Mismatched Ranks");
|
||||
static_assert(tuple_size<Coord>::value == tuple_size<DShape>::value, "Mismatched Ranks");
|
||||
return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); });
|
||||
} else {
|
||||
// assert(size(src_shape) == size(dst_shape))
|
||||
return idx2crd(crd2idx(coord, src_shape), dst_shape);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
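// Illustrative worked examples only:
//   crd2crd((1,2), (3,4), 12)             =>  idx2crd(crd2idx((1,2),(3,4)), 12) = 7
//   crd2crd((1,(0,1)), (3,(2,2)), (3,4))  =>  (1, 2)   // applied mode-by-mode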
|
||||
|
||||
//
|
||||
// Compact Major
|
||||
//
|
||||
|
||||
// General tag for common layouts and dispatching
|
||||
struct GenColMajor {};
|
||||
struct GenRowMajor {};
|
||||
|
||||
template <class Shape, class Current = Int<1>, class Major = GenColMajor>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_major(Shape const& shape,
|
||||
Current const& current = {},
|
||||
Major const& major = {});
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Shape, class Current, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_major_ti(Shape const& shape,
|
||||
Current const& current,
|
||||
GenColMajor const& major, seq<Is...>)
|
||||
{
|
||||
return cute::make_tuple(compact_major(get<Is>(shape), current * product<0,Is>(shape), major)...);
|
||||
}
|
||||
|
||||
template <class Shape, class Current, int... Is>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_major_ti(Shape const& shape,
|
||||
Current const& current,
|
||||
GenRowMajor const& major, seq<Is...>)
|
||||
{
|
||||
constexpr int E = tuple_size<Shape>::value;
|
||||
return cute::make_tuple(compact_major(get<Is>(shape), current * product<Is+1,E>(shape), major)...);
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Shape, class Current, class Major>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_major(Shape const& shape,
|
||||
Current const& current,
|
||||
Major const& major)
|
||||
{
|
||||
if constexpr (is_tuple<Shape>::value) {
|
||||
if constexpr (is_tuple<Current>::value) { // tuple tuple
|
||||
static_assert(tuple_size<Shape>::value == tuple_size<Current>::value, "Mismatched Ranks");
|
||||
return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c,major); });
|
||||
} else { // tuple int
|
||||
return detail::compact_major_ti(shape, current, major, tuple_seq<Shape>{});
|
||||
}
|
||||
} else {
|
||||
if constexpr (is_tuple<Current>::value) { // int tuple
|
||||
static_assert(sizeof(Shape) == 0, "Invalid parameters");
|
||||
} else { // int int
|
||||
if constexpr (is_constant<1, Shape>::value) {
|
||||
return Int<0>{}; // If current is dynamic, this could save a reg
|
||||
} else {
|
||||
return current;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Compact Col Major
|
||||
//
|
||||
|
||||
template <class Shape, class Current = Int<1>>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_col_major(Shape const& shape,
|
||||
Current const& current = {})
|
||||
{
|
||||
return compact_major(shape, current, GenColMajor{});
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
using ColMajor = decltype(compact_col_major(std::declval<Shape>()));
|
||||
|
||||
//
|
||||
// Compact Row Major
|
||||
//
|
||||
|
||||
template <class Shape, class Current = Int<1>>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_row_major(Shape const& shape,
|
||||
Current const& current = {})
|
||||
{
|
||||
return compact_major(shape, current, GenRowMajor{});
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
using RowMajor = decltype(compact_row_major(std::declval<Shape>()));
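// Illustrative examples only (assuming a dynamic shape (3,4,5)):
//   compact_col_major((3,4,5))  =>  (_1,3,12)   // leftmost mode varies fastest
//   compact_row_major((3,4,5))  =>  (20,5,_1)   // rightmost mode varies fastest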
|
||||
|
||||
//
|
||||
// Compact Order -- compute a compact stride based on an ordering of the modes
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Shape, class Order, class OrigShape, class OrigOrder>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_order(Shape const& shape, Order const& order,
|
||||
OrigShape const& orig_shape, OrigOrder const& orig_order)
|
||||
{
|
||||
if constexpr (is_tuple<Order>::value) {
|
||||
return transform(shape, order, [&](auto const& x, auto const& y) { return compact_order(x, y, orig_shape, orig_order); });
|
||||
} else {
|
||||
auto d = product(transform(orig_shape, orig_order,
|
||||
[&](auto const& s, auto const& o) {
|
||||
return conditional_return(o < order, product(s), Int<1>{});
|
||||
}));
|
||||
return compact_col_major(shape, d);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class Shape, class Order>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_order(Shape const& shape, Order const& order)
|
||||
{
|
||||
static_assert(is_congruent<Shape,Order>::value, "Need congruence of shape and order.");
|
||||
return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order));
|
||||
}
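// Illustrative example only: order gives each mode's position in the stride ordering
//   compact_order((2,3,4), (2,0,1))  =>  (12,_1,3)   // mode 1 varies fastest, then mode 2, then mode 0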
|
||||
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_order(Shape const& shape, GenColMajor const& major)
|
||||
{
|
||||
return compact_major(shape, Int<1>{}, major);
|
||||
}
|
||||
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compact_order(Shape const& shape, GenRowMajor const& major)
|
||||
{
|
||||
return compact_major(shape, Int<1>{}, major);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
497
include/cute/swizzle.hpp (new file)
@ -0,0 +1,497 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/algorithm/tuple_algorithms.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/numeric/math.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// A generic Swizzle functor
|
||||
/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
|
||||
* ^--^ MBase is the number of least-sig bits to keep constant
|
||||
* ^-^ ^-^ BBits is the number of bits in the mask
|
||||
* ^---------^ SShift is the distance to shift the YYY mask
|
||||
* (pos shifts YYY to the right, neg shifts YYY to the left)
|
||||
*
|
||||
* e.g. Given
|
||||
* 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
|
||||
* the result is
|
||||
* 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
|
||||
*/
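/* Illustrative worked example only:
 *   Swizzle<2,0,3>{}(0b11010)  =>  0b11010 ^ ((0b11010 & 0b11000) >> 3) = 0b11001
 *   i.e. the two ZZ bits are XORed with the two YY bits three positions above them.
 */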
|
||||
template <int BBits, int MBase, int SShift = BBits>
|
||||
struct Swizzle
|
||||
{
|
||||
static constexpr int num_bits = BBits;
|
||||
static constexpr int num_base = MBase;
|
||||
static constexpr int num_shft = SShift;
|
||||
|
||||
static_assert(num_base >= 0, "MBase must be non-negative.");
|
||||
static_assert(num_bits >= 0, "BBits must be non-negative.");
|
||||
static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be at least BBits.");
|
||||
|
||||
// using 'int' type here to avoid unintentionally casting to unsigned... unsure.
|
||||
using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
|
||||
using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
|
||||
using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>;
|
||||
using msk_sft = cute::constant<int, num_shft>;
|
||||
|
||||
static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{});
|
||||
|
||||
template <class Offset,
|
||||
__CUTE_REQUIRES(is_integral<Offset>::value)>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
apply(Offset const& offset)
|
||||
{
|
||||
return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY
|
||||
}
|
||||
|
||||
template <class Offset,
|
||||
__CUTE_REQUIRES(is_integral<Offset>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator()(Offset const& offset) const
|
||||
{
|
||||
return apply(offset);
|
||||
}
|
||||
};
|
||||
|
||||
// Translation for legacy SwizzleXor
|
||||
// TODO: Deprecate
|
||||
template <uint32_t BBits, uint32_t MBase, uint32_t SShift = 0>
|
||||
using SwizzleXor = Swizzle<BBits, MBase, SShift+BBits>;
|
||||
|
||||
//
|
||||
// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1>
|
||||
// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5>
|
||||
//
|
||||
|
||||
template <uint32_t Y, uint32_t Z>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_swizzle()
|
||||
{
|
||||
constexpr uint32_t BZ = popcount(Y); // Number of swizzle bits
|
||||
constexpr uint32_t BY = popcount(Z); // Number of swizzle bits
|
||||
static_assert(BZ == BY, "Number of set bits in Y and Z must match");
|
||||
constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y
|
||||
constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z
|
||||
constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32;
|
||||
constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros
|
||||
static_assert((Y | Z) == Swizzle<BZ,M,S>::swizzle_code, "Y and Z masks are inconsistent with the deduced Swizzle.");
|
||||
return Swizzle<BZ,M,S>{};
|
||||
}
|
||||
|
||||
template <int B0, int M0, int S0,
|
||||
int B1, int M1, int S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
composition(Swizzle<B0,M0,S0>, Swizzle<B1,M1,S1>)
|
||||
{
|
||||
static_assert(S0 == S1, "Can only merge swizzles of the same shift.");
|
||||
constexpr uint32_t Y = Swizzle<B0,M0,S0>::yyy_msk::value ^ Swizzle<B1,M1,S1>::yyy_msk::value;
|
||||
constexpr uint32_t Z = Swizzle<B0,M0,S0>::zzz_msk::value ^ Swizzle<B1,M1,S1>::zzz_msk::value;
|
||||
return make_swizzle<Y,Z>();
|
||||
|
||||
//return ComposedFn<Swizzle<B0,M0,S0>, Swizzle<B1,M1,S1>>{};
|
||||
}
|
||||
|
||||
//
|
||||
// Upcast and Downcast
|
||||
//
|
||||
|
||||
template <int N, int B, int M, int S>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
upcast(Swizzle<B,M,S> const& swizzle)
|
||||
{
|
||||
static_assert(has_single_bit(N), "N must be a power of two");
|
||||
constexpr int log2_n = bit_width(uint32_t(N)) - 1;
|
||||
constexpr int NewM = M - log2_n;
|
||||
if constexpr (NewM >= 0) {
|
||||
return Swizzle<B,NewM,S>{};
|
||||
} else {
|
||||
return Swizzle<cute::max(B+NewM,0), 0, S>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <int N, int B, int M, int S>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
downcast(Swizzle<B,M,S> const& swizzle)
|
||||
{
|
||||
static_assert(has_single_bit(N), "N must be a power of two");
|
||||
constexpr int log2_n = bit_width(uint32_t(N)) - 1;
|
||||
return Swizzle<B,(M + log2_n),S>{};
|
||||
}
|
||||
|
||||
template <class OldType, class NewType,
|
||||
int B, int M, int S>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(Swizzle<B,M,S> const& swizzle)
|
||||
{
|
||||
if constexpr (sizeof_bits<NewType>::value == sizeof_bits<OldType>::value) {
|
||||
return swizzle;
|
||||
} else if constexpr (sizeof_bits<NewType>::value > sizeof_bits<OldType>::value) {
|
||||
static_assert(sizeof_bits<NewType>::value % sizeof_bits<OldType>::value == 0, "NewType must be a multiple of OldType");
|
||||
return upcast<sizeof_bits<NewType>::value/sizeof_bits<OldType>::value>(swizzle);
|
||||
} else if constexpr (sizeof_bits<NewType>::value < sizeof_bits<OldType>::value) {
|
||||
static_assert(sizeof_bits<OldType>::value % sizeof_bits<NewType>::value == 0, "NewType must be a divisor of OldType");
|
||||
return downcast<sizeof_bits<OldType>::value/sizeof_bits<NewType>::value>(swizzle);
|
||||
}
|
||||
}
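// Illustrative example only:
//   recast<uint16_t,uint32_t>(Swizzle<3,4,3>{})  =>  upcast<2>(Swizzle<3,4,3>{})  =>  Swizzle<3,3,3>{}
//   (offsets measured in 32-bit elements are half the 16-bit offsets, so MBase drops by one bit)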
|
||||
|
||||
//
|
||||
// Utility for slicing and swizzle "offsets"
|
||||
//
|
||||
|
||||
// For swizzle functions, it is often necessary to keep track of which bits are
|
||||
// consumed and which bits are free. Furthermore, it is useful to know whether
|
||||
// each of these bits is known statically or dynamically.
|
||||
|
||||
// MixedBits is an integer class where some bits are known statically and some
|
||||
// bits are known dynamically. These sets of bits are disjoint and it is known
|
||||
// statically which bits are known dynamically.
|
||||
|
||||
// MixedBits can only be manipulated through bitwise operations
|
||||
|
||||
// Abstract value: StaticInt | (dynamic_int_ & StaticFlags)
|
||||
template <uint32_t StaticInt = 0,
|
||||
class DynamicType = uint32_t,
|
||||
uint32_t StaticFlags = 0> // 0: static, 1: dynamic
|
||||
struct MixedBits
|
||||
{
|
||||
// Representation invariants
|
||||
static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits.");
|
||||
static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits.");
|
||||
// assert((dynamic_int_ & ~F) == 0);
|
||||
|
||||
DynamicType dynamic_int_;
|
||||
};
|
||||
|
||||
template <class S, S s, class DynamicType, class F, F f>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_mixed_bits(constant<S,s> const&, DynamicType const& d, constant<F,f> const&)
|
||||
{
|
||||
static_assert(is_integral<DynamicType>::value);
|
||||
if constexpr (is_static<DynamicType>::value) {
|
||||
static_assert((s & DynamicType::value & f) == 0, "No static/dynamic overlap allowed.");
|
||||
return constant<S,s>{} | (d & constant<F,f>{}); // Just return a static int
|
||||
} else if constexpr (f == 0) {
|
||||
return constant<S,s>{}; // Just return a static int
|
||||
} else {
|
||||
return MixedBits<s, DynamicType, f>{d & f}; // MixedBits
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
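// Illustrative example only (d is a hypothetical runtime uint32_t):
//   make_mixed_bits(Int<0b100>{}, d, Int<0b010>{})  =>  MixedBits<0b100, uint32_t, 0b010>{d & 0b010}
//   i.e. bit 2 is statically 1, bit 1 is dynamic, and all remaining bits are statically 0.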
|
||||
|
||||
//
|
||||
// Explicit conversion for now -- consider casting on plus or minus
|
||||
//
|
||||
|
||||
template <uint32_t S, class D, uint32_t F>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_integral(MixedBits<S,D,F> const& m)
|
||||
{
|
||||
//return S | (m.dynamic_int_ & F);
|
||||
return S | m.dynamic_int_;
|
||||
}
|
||||
|
||||
// Any cute::is_integral
|
||||
template <class I, __CUTE_REQUIRES(cute::is_integral<I>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_integral(I const& i)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
|
||||
//
|
||||
// Operators
|
||||
//
|
||||
|
||||
// Equality
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
|
||||
{
|
||||
return (S0 == (S1 & ~F0)) && (m.dynamic_int_ == (S1 & F0));
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator==(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
|
||||
{
|
||||
return m == s;
|
||||
}
|
||||
|
||||
// Bitwise AND
|
||||
template <uint32_t S0, class D0, uint32_t F0,
|
||||
uint32_t S1, class D1, uint32_t F1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator&(MixedBits<S0,D0,F0> const& m0, MixedBits<S1,D1,F1> const& m1)
|
||||
{
|
||||
// Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F)
|
||||
// S0D0F0 | 0X0 | 001 | 011 | 1X0 |
|
||||
// S1D1F1
|
||||
// 0X0 | 0X0 | 0X0 | 0X0 | 0X0 |
|
||||
// 001 | 0X0 | 001 | 001 | 001 |
|
||||
// 011 | 0X0 | 001 | 011 | 011 |
|
||||
// 1X0 | 0X0 | 001 | 011 | 1X0 |
|
||||
|
||||
return make_mixed_bits(constant<uint32_t,S0 & S1>{},
|
||||
//(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_),
|
||||
((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_),
|
||||
constant<uint32_t,(S1 & F0) | (S0 & F1) | (F0 & F1)>{});
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator&(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
|
||||
{
|
||||
return make_mixed_bits(constant<uint32_t,S0 & S1>{},
|
||||
m.dynamic_int_,
|
||||
constant<uint32_t,S1 & F0>{});
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator&(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
|
||||
{
|
||||
return m & s;
|
||||
}
|
||||
|
||||
// Bitwise OR
|
||||
template <uint32_t S0, class D0, uint32_t F0,
|
||||
uint32_t S1, class D1, uint32_t F1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator|(MixedBits<S0,D0,F0> const& m0, MixedBits<S1,D1,F1> const& m1)
|
||||
{
|
||||
// Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F)
|
||||
// S0D0F0 | 0X0 | 001 | 011 | 1X0 |
|
||||
// S1D1F1
|
||||
// 0X0 | 0X0 | 001 | 011 | 1X0 |
|
||||
// 001 | 001 | 001 | 011 | 1X0 |
|
||||
// 011 | 011 | 011 | 011 | 1X0 |
|
||||
// 1X0 | 1X0 | 1X0 | 1X0 | 1X0 |
|
||||
|
||||
return make_mixed_bits(constant<uint32_t,S0 | S1>{},
|
||||
((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_),
|
||||
constant<uint32_t,(~S0 & F1) | (~S1 & F0)>{});
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator|(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
|
||||
{
|
||||
return make_mixed_bits(constant<uint32_t,S0 | S1>{},
|
||||
m.dynamic_int_,
|
||||
constant<uint32_t,~S1 & F0>{});
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator|(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
|
||||
{
|
||||
return m | s;
|
||||
}
|
||||
|
||||
// Bitwise XOR
|
||||
template <uint32_t S0, class D0, uint32_t F0,
|
||||
uint32_t S1, class D1, uint32_t F1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator^(MixedBits<S0,D0,F0> const& m0, MixedBits<S1,D1,F1> const& m1)
|
||||
{
|
||||
// Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F)
|
||||
// S0D0F0 | 0X0 | 001 | 011 | 1X0 |
|
||||
// S1D1F1
|
||||
// 0X0 | 0X0 | 001 | 011 | 1X0 |
|
||||
// 001 | 001 | 001 | 011 | 011 |
|
||||
// 011 | 011 | 011 | 001 | 001 |
|
||||
// 1X0 | 1X0 | 011 | 001 | 0X0 |
|
||||
|
||||
return make_mixed_bits(constant<uint32_t,(~S0 & S1 & ~F0) | (S0 & ~S1 & ~F1)>{},
|
||||
(S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_),
|
||||
constant<uint32_t,F0 | F1>{});
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator^(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const&)
|
||||
{
|
||||
return make_mixed_bits(constant<uint32_t,(~S0 & S1 & ~F0) | (S0 & ~S1)>{},
|
||||
(S0 | m.dynamic_int_) ^ S1,
|
||||
constant<uint32_t,F0>{});
|
||||
}
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
operator^(constant<TS1,S1> const& s, MixedBits<S0,D0,F0> const& m)
|
||||
{
|
||||
return m ^ s;
|
||||
}
|
||||
|
||||
//
|
||||
// upcast and downcast
|
||||
//
|
||||
|
||||
template <uint32_t S0, class D0, uint32_t F0, class TS1, TS1 S1>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
safe_div(MixedBits<S0,D0,F0> const& m, constant<TS1,S1> const& s)
|
||||
{
|
||||
static_assert(has_single_bit(S1), "Only divide MixedBits by powers of two.");
|
||||
return make_mixed_bits(safe_div(constant<uint32_t,S0>{}, s),
|
||||
safe_div(m.dynamic_int_, s),
|
||||
safe_div(constant<uint32_t,F0>{}, s));
|
||||
}
|
||||
|
||||
template <uint32_t N, uint32_t S0, class D0, uint32_t F0>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
upcast(MixedBits<S0,D0,F0> const& m)
|
||||
{
|
||||
static_assert(has_single_bit(N), "Only divide MixedBits by powers of two.");
|
||||
return safe_div(m, constant<uint32_t,N>{});
|
||||
}
|
||||
|
||||
template <uint32_t N, class T, __CUTE_REQUIRES(cute::is_integral<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
upcast(T const& m)
|
||||
{
|
||||
return safe_div(m, constant<uint32_t,N>{});
|
||||
}
|
||||
|
||||
template <uint32_t N, uint32_t S0, class D0, uint32_t F0>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
downcast(MixedBits<S0,D0,F0> const& m)
|
||||
{
|
||||
static_assert(has_single_bit(N), "Only scale MixedBits by powers of two.");
|
||||
return make_mixed_bits(constant<uint32_t,S0 * N>{},
|
||||
m.dynamic_int_ * N,
|
||||
constant<uint32_t,F0 * N>{});
|
||||
}
|
||||
|
||||
template <uint32_t N, class T, __CUTE_REQUIRES(cute::is_integral<T>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
downcast(T const& m)
|
||||
{
|
||||
return m * constant<uint32_t, N>{};
|
||||
}
|
||||
|
||||
//
|
||||
// Convert a Pow2Layout+Coord to a MixedBits
|
||||
//
|
||||
|
||||
template <class Shape, class Stride, class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord)
|
||||
{
|
||||
if constexpr (is_tuple<Shape>::value && is_tuple<Stride>::value && is_tuple<Coord>::value) {
|
||||
static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched ranks");
|
||||
static_assert(tuple_size<Shape>::value == tuple_size<Coord >::value, "Mismatched ranks");
|
||||
return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); },
|
||||
[](auto const&... a) { return (a ^ ...); });
|
||||
} else if constexpr (is_integral<Shape>::value && is_integral<Stride>::value && is_integral<Coord>::value) {
|
||||
static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride.");
|
||||
return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride);
|
||||
} else {
|
||||
static_assert(is_integral<Shape>::value && is_integral<Stride>::value && is_integral<Coord>::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral).");
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class Layout, class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_mixed_bits(Layout const& layout, Coord const& coord)
|
||||
{
|
||||
return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape()));
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
template <uint32_t S, class D, uint32_t F>
|
||||
CUTE_HOST_DEVICE void print(MixedBits<S,D,F> const& m)
|
||||
{
|
||||
printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m));
|
||||
}
|
||||
|
||||
template <uint32_t S, class D, uint32_t F>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits<S,D,F> const& m)
|
||||
{
|
||||
return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << to_integral(m);
|
||||
}
|
||||
|
||||
template <int B, int M, int S>
|
||||
CUTE_HOST_DEVICE void print(Swizzle<B,M,S> const&)
|
||||
{
|
||||
print("S<%d,%d,%d>", B, M, S);
|
||||
}
|
||||
|
||||
template <int B, int M, int S>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle<B,M,S> const&)
|
||||
{
|
||||
return os << "S<" << B << "," << M << "," << S << ">";
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
1010
include/cute/swizzle_layout.hpp (new file)
File diff suppressed because it is too large
282
include/cute/swizzle_ptr.hpp (new file)
@ -0,0 +1,282 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/util.hpp>
|
||||
|
||||
#include <cute/swizzle.hpp>
|
||||
#include <cute/swizzle_layout.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cute/pointer.hpp>
|
||||
#include <cute/container/array.hpp>
|
||||
#include <cute/numeric/int.hpp>
|
||||
|
||||
/* This implements a swizzle pointer of the form
|
||||
* InvolutionFn o PtrAdd
|
||||
* where the InvolutionFn need not be linear.
|
||||
*
|
||||
* This differs subtly from swizzle_layout because the smem pointer is used
|
||||
* as the offset. That means that swizzle_layout will implement position-independent
|
||||
* swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors.
|
||||
 * The hardware architecture was designed with position-dependent swizzles.
|
||||
*
|
||||
* For clarity:
|
||||
* NormalLayout : DeRef <- PtrAdd <- [Layout]
|
||||
* ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout]
|
||||
* SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout
|
||||
*
|
||||
* Furthermore, for known swizzles, this pointer attempts to decay itself
|
||||
* to a normal-pointer with a new layout containing dynamic or static strides.
|
||||
* This is possible by determining the subdomain of the InvolutionFn
|
||||
* that is identity and testing if the Layout's codomain is contained
|
||||
* within it.
|
||||
*/
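/* Illustrative usage sketch only (smem_data, the shape, and the particular swizzle are hypothetical):
 *   half_t* smem_data = ...;                                          // raw shared-memory pointer
 *   auto sptr    = make_smem_ptr(smem_data, Swizzle<3,3,3>{});        // smem_ptr_swizzle<half_t, Swizzle<3,3,3>>
 *   auto stensor = make_tensor(sptr, make_layout(make_shape(Int<64>{}, Int<8>{})));
 *   // Each access stensor(i,j) applies the swizzle to the raw smem address before dereferencing.
 */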
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class T, class Swizzle>
|
||||
struct smem_ptr_swizzle
|
||||
{
|
||||
static_assert(std::is_empty<Swizzle>::value, "Swizzle can't have state.");
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T* get() const
|
||||
{
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
Swizzle get_swizzle()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
T* apply_swizzle(T* ptr)
|
||||
{
|
||||
return reinterpret_cast<T*>(Swizzle::apply(reinterpret_cast<std::uintptr_t>(ptr)));
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& operator*() const
|
||||
{
|
||||
return *apply_swizzle(get());
|
||||
}
|
||||
|
||||
template <class Int>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T& operator[](Int const& i) const
|
||||
{
|
||||
return *apply_swizzle(get() + i);
|
||||
}
|
||||
|
||||
template <class Int>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
smem_ptr_swizzle operator+(Int const& i) const
|
||||
{
|
||||
return {ptr_ + i};
|
||||
}
|
||||
|
||||
T* ptr_;
|
||||
};
|
||||
|
||||
template <class T, class S>
|
||||
struct is_smem<smem_ptr_swizzle<T,S>> : true_type {};
|
||||
|
||||
// Make a swizzle pointer
|
||||
template <class T, class Swizzle>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_smem_ptr(T* ptr, Swizzle const& swizzle)
|
||||
{
|
||||
return smem_ptr_swizzle<T,Swizzle>{ptr};
|
||||
}
|
||||
|
||||
// A model of a nullptr smem_ptr<T> with B == sizeof_bits<T>::value
|
||||
// that represents an unset pointer. It is a placeholder type that is waiting for an smem_ptr.
|
||||
template <int Bits>
|
||||
struct smem_ptr_flag_bits : Int<0> {};
|
||||
|
||||
using smem_ptr_flag = smem_ptr_flag_bits<1>;
|
||||
|
||||
// A flagged construction method to transform ComposedLayout
|
||||
// Make a swizzle pointer tensor and check that the intended type size matches
|
||||
template <class T, class Swizzle, int B, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(smem_ptr<T> const& ptr,
|
||||
ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
static_assert(B == sizeof_bits<T>::value, "Expected a B-bit pointer type.");
|
||||
return make_tensor(make_smem_ptr(ptr.get(), layout.swizzle_fn()),
|
||||
layout.layout_fn());
|
||||
}
|
||||
|
||||
// Specialization for immediate decay
|
||||
template <class T, int M, int S, class LShape, class LStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(smem_ptr_swizzle<T,Swizzle<0,M,S>>& p, Layout<LShape,LStride> const& layout)
|
||||
{
|
||||
return make_tensor(make_smem_ptr(p.ptr_), layout);
|
||||
}
|
||||
|
||||
template <class T, int M, int S, class LShape, class LStride>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(smem_ptr_swizzle<T,Swizzle<0,M,S>> const& p, Layout<LShape,LStride> const& layout)
|
||||
{
|
||||
return make_tensor(make_smem_ptr(p.ptr_), layout);
|
||||
}
|
||||
|
||||
// NOTE: To preserve smem_ptr_flag_bits under recast ops
|
||||
template <int N, class Swizzle, int B, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
upcast(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
return composition(layout.swizzle_fn(), smem_ptr_flag_bits<B*N>{}, upcast<N>(layout.layout_fn()));
|
||||
}
|
||||
|
||||
template <int N, class Swizzle, int B, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
downcast(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
return composition(layout.swizzle_fn(), smem_ptr_flag_bits<B/N>{}, downcast<N>(layout.layout_fn()));
|
||||
}
|
||||
|
||||
//
|
||||
// Recast
|
||||
// Swizzle operates on the pointer address, so it doesn't care about the type
|
||||
//
|
||||
|
||||
template <class NewT, class T, class Swizzle>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(smem_ptr_swizzle<T,Swizzle> const& ptr)
|
||||
{
|
||||
return smem_ptr_swizzle<NewT,Swizzle>{recast<NewT>(ptr.ptr_)};
|
||||
}
|
||||
|
||||
template <class NewT, class T, class Swizzle>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(smem_ptr_swizzle<T const,Swizzle> const& ptr)
|
||||
{
|
||||
return smem_ptr_swizzle<NewT const,Swizzle>{recast<NewT const>(ptr.ptr_)};
|
||||
}
|
||||
|
||||
//
|
||||
// Conversion with swizzle_layout
|
||||
//
|
||||
|
||||
template <class T, class Swizzle, int B, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
as_position_independent_swizzle_layout(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
return composition(recast<uint_bit_t<8>,uint_bit_t<B>>(layout.swizzle_fn()), Int<0>{}, layout.layout_fn());
|
||||
}
|
||||
|
||||
template <class T, class Swizzle, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
as_position_independent_swizzle_tensor(Tensor<ViewEngine<smem_ptr_swizzle<T,Swizzle>>, Layout> const& tensor)
|
||||
{
|
||||
{
|
||||
uint32_t address = cast_smem_ptr_to_uint(tensor.data().get());
|
||||
uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code);
|
||||
assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle
|
||||
}
|
||||
auto new_swizzle = recast<uint_bit_t<8>,uint_bit_t<sizeof_bits_v<T>>>(tensor.data().get_swizzle());
|
||||
return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout()));
|
||||
}
|
||||
|
||||
template <class T, class Swizzle, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
as_position_independent_swizzle_tensor(Tensor<ViewEngine<smem_ptr_swizzle<T,Swizzle>>, Layout>& tensor)
|
||||
{
|
||||
{
|
||||
uint32_t address = cast_smem_ptr_to_uint(tensor.data().get());
|
||||
uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code);
|
||||
assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle
|
||||
}
|
||||
auto new_swizzle = recast<uint_bit_t<8>,uint_bit_t<sizeof_bits_v<T>>>(tensor.data().get_swizzle());
|
||||
return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout()));
|
||||
}
|
||||
|
||||
template <class T, class Swizzle, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
as_position_independent_swizzle_tensor(Tensor<ViewEngine<smem_ptr_swizzle<T,Swizzle>>, Layout>&& tensor)
|
||||
{
|
||||
return as_position_independent_swizzle_tensor(tensor);
|
||||
}
|
||||
|
||||
//
|
||||
// Print
|
||||
//
|
||||
|
||||
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts
|
||||
template <class Swizzle, int B, class Layout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_latex(ComposedLayout<Swizzle,smem_ptr_flag_bits<B>,Layout> const& layout)
|
||||
{
|
||||
auto new_swizzle = recast<uint_bit_t<8>,uint_bit_t<B>>(layout.swizzle_fn());
|
||||
print_latex(composition(new_swizzle, Int<0>{}, layout.layout_fn()));
|
||||
}
|
||||
|
||||
template <int B>
|
||||
CUTE_HOST_DEVICE void print(smem_ptr_flag_bits<B> const& ptr)
|
||||
{
|
||||
printf("smem_ptr_%db(unset)", B);
|
||||
}
|
||||
|
||||
template <class T, int B, int M, int S>
|
||||
CUTE_HOST_DEVICE void print(smem_ptr_swizzle<T,Swizzle<B,M,S>> const& ptr)
|
||||
{
|
||||
printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get());
|
||||
}
|
||||
|
||||
template <class T, int B, int M, int S>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle<T,Swizzle<B,M,S>> const&)
|
||||
{
|
||||
return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b";
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
900
include/cute/tensor.hpp (new file)
@ -0,0 +1,900 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/container/array_aligned.hpp>
|
||||
#include <cute/container/array_subbyte.hpp>
|
||||
#include <cute/container/type_list.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
#include <cute/tile.hpp>
|
||||
#include <cute/pointer.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// Engine -- owning or non-owning data store
|
||||
//
|
||||
|
||||
// concept Engine {
|
||||
// using value_type = ;
|
||||
// iterator begin();
|
||||
// };
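// e.g. (illustrative only)
//   ViewEngine<float*>        -- non-owning view over a raw pointer
//   ArrayEngine<half_t, 128>  -- owning storage for 128 elements (e.g. a register fragment)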
|
||||
|
||||
template <class T, int N>
|
||||
using ArrayEngine = typename std::conditional<(sizeof_bits<T>::value % 8 == 0),
|
||||
array_aligned<T,N>,
|
||||
array_subbyte<T,N>>::type;
|
||||
|
||||
template <class Iterator>
|
||||
struct ViewEngine
|
||||
{
|
||||
using value_type = typename cute::remove_cvref<decltype(*std::declval<Iterator>())>::type;
|
||||
|
||||
using iterator = Iterator;
|
||||
iterator storage_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator const&
|
||||
begin() const {
|
||||
return storage_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator&
|
||||
begin() {
|
||||
return storage_;
|
||||
}
|
||||
};
|
||||
|
||||
template <class Iter>
|
||||
struct is_rmem<ViewEngine<Iter>> : is_rmem<Iter> {};
|
||||
template <class Iter>
|
||||
struct is_smem<ViewEngine<Iter>> : is_smem<Iter> {};
|
||||
template <class Iter>
|
||||
struct is_gmem<ViewEngine<Iter>> : is_gmem<Iter> {};
|
||||
template <class Iterator>
|
||||
struct ConstViewEngine
|
||||
{
|
||||
using value_type = typename cute::remove_cvref<decltype(*std::declval<Iterator>())>::type;
|
||||
|
||||
using iterator = Iterator;
|
||||
iterator storage_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
iterator const&
|
||||
begin() const {
|
||||
return storage_;
|
||||
}
|
||||
};
|
||||
|
||||
template <class Iter>
|
||||
struct is_rmem<ConstViewEngine<Iter>> : is_rmem<Iter> {};
|
||||
template <class Iter>
|
||||
struct is_smem<ConstViewEngine<Iter>> : is_smem<Iter> {};
|
||||
template <class Iter>
|
||||
struct is_gmem<ConstViewEngine<Iter>> : is_gmem<Iter> {};
|
||||
//
|
||||
// Tensor
|
||||
//
|
||||
|
||||
template <class Engine, class Layout>
|
||||
struct Tensor
|
||||
{
|
||||
using value_type = typename Engine::value_type;
|
||||
//using pointer = typename engine_traits<Engine>::pointer;
|
||||
//using const_pointer = typename engine_traits<Engine>::const_pointer;
|
||||
//using reference = typename engine_traits<Engine>::reference;
|
||||
//using const_reference = typename engine_traits<Engine>::const_reference;
|
||||
|
||||
using engine_type = Engine;
|
||||
using layout_type = Layout;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Tensor() {}
|
||||
|
||||
template <class Ptr>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Tensor(Ptr const& ptr, Layout const& layout)
|
||||
: rep_(layout, ptr) {
|
||||
}
|
||||
|
||||
//
|
||||
// Accessors
|
||||
//
|
||||
|
||||
static constexpr int rank = Layout::rank;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
tensor() const {
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
layout() const {
|
||||
return get<0>(rep_);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
engine() const {
|
||||
return get<1>(rep_);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
engine() {
|
||||
return get<1>(rep_);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
data() const {
|
||||
return engine().begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
data() {
|
||||
return engine().begin();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
shape() const {
|
||||
return layout().shape();
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
size() const {
|
||||
return cute::size(shape());
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
stride() const {
|
||||
return layout().stride();
|
||||
}
|
||||
|
||||
//
|
||||
// Indexing op() and op[]
|
||||
//
|
||||
|
||||
// Index into this tensor like an array by computing the offset via layout()
|
||||
template <class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator[](Coord const& coord) {
|
||||
return data()[layout()(coord)];
|
||||
}
|
||||
|
||||
template <class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator[](Coord const& coord) const {
|
||||
return data()[layout()(coord)];
|
||||
}
|
||||
|
||||
template <class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(Coord const& coord) {
|
||||
if constexpr (has_underscore<Coord>::value) {
|
||||
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
|
||||
return make_tensor(data() + offset, sliced_layout);
|
||||
} else {
|
||||
return data()[layout()(coord)];
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class Coord>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(Coord const& coord) const {
|
||||
if constexpr (has_underscore<Coord>::value) {
|
||||
auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
|
||||
return make_tensor(data() + offset, sliced_layout);
|
||||
} else {
|
||||
return data()[layout()(coord)];
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// op() convenience function for multi-dimensional coordinates
|
||||
template <class Coord0, class Coord1, class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) {
|
||||
return operator()(make_coord(c0,c1,cs...));
|
||||
}
|
||||
|
||||
template <class Coord0, class Coord1, class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
|
||||
return operator()(make_coord(c0,c1,cs...));
|
||||
}
|
||||
|
||||
//
|
||||
// Compose
|
||||
//
|
||||
|
||||
template <class... Layouts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compose(Layouts const&... layouts) {
|
||||
return make_tensor(data(), layout().compose(layouts...));
|
||||
}
|
||||
|
||||
template <class... Layouts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
compose(Layouts const&... layouts) const {
|
||||
return make_tensor(data(), layout().compose(layouts...));
|
||||
}
|
||||
|
||||
//
|
||||
// Tile
|
||||
//
|
||||
|
||||
template <class... Layouts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tile(Layouts const&... layouts) {
|
||||
return make_tensor(data(), layout().tile(layouts...));
|
||||
}
|
||||
|
||||
template <class... Layouts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tile(Layouts const&... layouts) const {
|
||||
return make_tensor(data(), layout().tile(layouts...));
|
||||
}
|
||||
|
||||
//
|
||||
// Utility
|
||||
//
|
||||
|
||||
template <class Int,
|
||||
__CUTE_REQUIRES(is_integral<Int>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_1d_coord(Int const& linear_idx) const {
|
||||
return layout().get_1d_coord(linear_idx);
|
||||
}
|
||||
|
||||
template <class Int,
|
||||
__CUTE_REQUIRES(is_integral<Int>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_hier_coord(Int const& linear_idx) const {
|
||||
return layout().get_hier_coord(linear_idx);
|
||||
}
|
||||
|
||||
template <class Int,
|
||||
__CUTE_REQUIRES(is_integral<Int>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_flat_coord(Int const& linear_idx) const {
|
||||
return layout().get_flat_coord(linear_idx);
|
||||
}
|
||||
|
||||
cute::tuple<layout_type, engine_type> rep_;
|
||||
};
|
||||
|
||||
|
||||
template <class Layout>
|
||||
struct is_tensor : false_type {};
|
||||
template <class Engine, class Layout>
|
||||
struct is_tensor<Tensor<Engine,Layout>> : true_type {};
|
||||
|
||||
template <class Engine, class Layout>
|
||||
struct is_rmem<Tensor<Engine,Layout>> : is_rmem<Engine> {};
|
||||
template <class Engine, class Layout>
|
||||
struct is_smem<Tensor<Engine,Layout>> : is_smem<Engine> {};
|
||||
template <class Engine, class Layout>
|
||||
struct is_gmem<Tensor<Engine,Layout>> : is_gmem<Engine> {};
|
||||
//
|
||||
// Make an owning Tensor that will allocate a static array
|
||||
//
|
||||
|
||||
template <class T, class Layout,
|
||||
__CUTE_REQUIRES(is_layout<Layout>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(Layout const& layout)
|
||||
{
|
||||
static_assert(is_static<Layout>::value, "Dynamic owning tensors not supported");
|
||||
using Engine = ArrayEngine<T, cosize_v<Layout>>;
|
||||
return Tensor<Engine,Layout>();
|
||||
}
|
||||
|
||||
// e.g. make_tensor<double>(12)
|
||||
template <class T, class LayoutArg, class... LayoutArgs,
|
||||
__CUTE_REQUIRES(not is_layout<LayoutArg>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(LayoutArg const& arg, LayoutArgs const&... args)
|
||||
{
|
||||
return make_tensor<T>(make_layout(arg, args...));
|
||||
}
|
||||
|
||||
//
|
||||
// Make a non-owning Tensor that will use a pointer (view)
|
||||
//
|
||||
|
||||
template <class Iterator, class Layout,
|
||||
__CUTE_REQUIRES(has_dereference<Iterator>::value &&
|
||||
is_layout<Layout>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(Iterator const& iter, Layout const& layout)
|
||||
{
|
||||
using Engine = ViewEngine<Iterator>;
|
||||
return Tensor<Engine,Layout>(iter, layout);
|
||||
}
|
||||
|
||||
// e.g. make_tensor(vec.data(), 12)
|
||||
template <class Iterator, class LayoutArg, class... LayoutArgs,
|
||||
__CUTE_REQUIRES(not is_layout<LayoutArg>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor(Iterator const& iter, LayoutArg const& arg, LayoutArgs const&... args)
|
||||
{
|
||||
return make_tensor(iter, make_layout(arg, args...));
|
||||
}
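// Usage sketch (illustrative; gptr, M, N, and make_gmem_ptr from cute/pointer.hpp are
// assumptions, not part of this header):
//
//   float* gptr = /* device allocation */;
//   // Non-owning view of an M x N row-major matrix in global memory
//   Tensor gA = make_tensor(make_gmem_ptr(gptr), make_shape(M, N), make_stride(N, Int<1>{}));
//   // Owning, register-backed tensor with a static layout
//   Tensor rA = make_tensor<float>(Layout<Shape<_4,_8>>{});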
|
||||
|
||||
//
|
||||
// make_tensor_like -- make a register tensor the same type and shape as another
|
||||
//
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tensor_like(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
using value_type = typename Tensor<Engine,Layout>::value_type;
|
||||
return make_tensor<value_type>(tensor.shape());
|
||||
}
|
||||
|
||||
//
|
||||
// make_fragment_like -- make a register tensor the same type, shape, and (if possible) order as another tensor
|
||||
//
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_fragment_like(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
using value_type = typename Tensor<Engine,Layout>::value_type;
|
||||
return make_tensor<value_type>(make_layout_like(tensor.layout()));
|
||||
}
|
||||
|
||||
//
|
||||
// make_identity_tensor
|
||||
//
|
||||
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_identity_tensor(Shape const& shape)
|
||||
{
|
||||
return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(shape, Int<0>{}))),
|
||||
make_identity_layout(shape));
|
||||
}
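// Usage sketch (illustrative): an identity tensor stores coordinates rather than values,
// so it is a convenient source for building predicate masks against the problem shape.
//
//   Tensor cA = make_identity_tensor(make_shape(M, N));   // cA(i,j) == (i,j)
//   // e.g. compare cA(i,j) against the problem bounds to guard partial tiles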
|
||||
|
||||
//
|
||||
// Utilities
|
||||
//
|
||||
|
||||
// Return the subtensor of a mode
|
||||
template <class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
tensor(Tensor&& tensor)
|
||||
{
|
||||
return std::forward<Tensor>(tensor);
|
||||
}
|
||||
|
||||
template <int I, int... Is, class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
tensor(Tensor&& tensor)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(), get<I,Is...>(tensor.layout()));
|
||||
}
|
||||
|
||||
// Return the subtensor of a range of modes
|
||||
template <int B, int E, class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
take(Tensor&& tensor)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(), take<B,E>(tensor.layout()));
|
||||
}
|
||||
|
||||
// Return the layout of a mode
|
||||
template <int... Is, class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
layout(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
return layout<Is...>(tensor.layout());
|
||||
}
|
||||
|
||||
// Return the shape of a mode
|
||||
template <int... Is, class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
shape(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
return shape<Is...>(tensor.layout());
|
||||
}
|
||||
|
||||
// Return the stride of a mode
|
||||
template <int... Is, class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
stride(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
return stride<Is...>(tensor.layout());
|
||||
}
|
||||
|
||||
// Return the number of elements in a mode
|
||||
template <int... Is, class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
decltype(auto)
|
||||
size(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
return size<Is...>(tensor.layout());
|
||||
}
|
||||
|
||||
// Return the rank of a mode
|
||||
template <int... Is, class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
rank(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
return rank<Is...>(tensor.layout());
|
||||
}
|
||||
|
||||
// Return the depth of a mode
|
||||
template <int... Is, class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
depth(Tensor<Engine, Layout> const& tensor)
|
||||
{
|
||||
return depth<Is...>(tensor.layout());
|
||||
}
|
||||
|
||||
//
|
||||
// Operations to manipulate Tensors like a Layout
|
||||
//
|
||||
|
||||
template <class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
flatten(Tensor&& tensor)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(), flatten(tensor.layout()));
|
||||
}
|
||||
|
||||
template <class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
coalesce(Tensor&& tensor)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(), coalesce(tensor.layout()));
|
||||
}
|
||||
|
||||
template <class Tensor, class Profile,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
coalesce(Tensor&& tensor, Profile const& profile)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(), coalesce(tensor.layout(), profile));
|
||||
}
|
||||
|
||||
// Group the modes [B,E) into a single mode
|
||||
// e.g. group<2,4>(make_tensor<int>(Layout<Shape<_1,_2,_3,_4,_5,_6>>{}))
|
||||
// => make_tensor<int>(Layout<Shape<_1,_2,Shape<_3,_4>,_5,_6>>{})
|
||||
template <int B, int E, class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
group_modes(Tensor&& tensor)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(),
|
||||
group<B,E>(tensor.layout()));
|
||||
}
|
||||
|
||||
//
|
||||
// Recast
|
||||
//
|
||||
|
||||
// NOTE: This is very dangerous to do
|
||||
// -- doesn't check dynamic integer divisibility
|
||||
// -- doesn't check alignment
|
||||
|
||||
// A tagged version for dispatching
|
||||
template <class NewType, class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(Tensor&& tensor, type_list<NewType>)
|
||||
{
|
||||
using OldType = typename remove_cvref_t<Tensor>::value_type;
|
||||
auto old_layout = tensor.layout();
|
||||
auto new_layout = recast<OldType,NewType>(old_layout);
|
||||
|
||||
// If this is an upcast of a normal Layout with static negative strides, then offset as well
|
||||
if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
|
||||
auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
|
||||
auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
|
||||
auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
|
||||
|
||||
return make_tensor(recast<NewType>(std::forward<Tensor>(tensor).data() + offset), new_layout);
|
||||
} else {
|
||||
return make_tensor(recast<NewType>(std::forward<Tensor>(tensor).data() ), new_layout);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <class NewType, class Tensor,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast(Tensor&& tensor)
|
||||
{
|
||||
return recast(std::forward<Tensor>(tensor), type_list<NewType>{});
|
||||
}
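// Usage sketch (illustrative): reinterpret a register fragment of half_t as uint32_t,
// halving the contiguous extent. As noted above, divisibility and alignment are not checked.
//
//   Tensor rA  = make_tensor<half_t>(Layout<Shape<_8,_4>>{});   // 32 x 16-bit values
//   Tensor rA2 = recast<uint32_t>(rA);                          // views rA as 16 x 32-bit values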
|
||||
|
||||
//
|
||||
// max_common_vector
|
||||
//
|
||||
|
||||
/* Return Int<N> such that N is the maximum number of contiguous elements
* that logically correspond in the tensors of @a a and @a b. That is,
* the number of elements that could reasonably be vectorized into a single load/store.
|
||||
*
|
||||
* @returns Int<N> with N >= 0
|
||||
*
|
||||
* A return value of Int<0> indicates that no such conclusion can be made and no
|
||||
* vectorization should be attempted.
|
||||
*/
|
||||
template <class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
max_common_vector(Tensor<SrcEngine,SrcLayout> const& a,
|
||||
Tensor<DstEngine,DstLayout> const& b)
|
||||
{
|
||||
using SrcType = typename Tensor<SrcEngine,SrcLayout>::value_type;
|
||||
using DstType = typename Tensor<DstEngine,DstLayout>::value_type;
|
||||
|
||||
using SrcRef = decltype(*(a.data()));
|
||||
using DstRef = decltype(*(b.data()));
|
||||
|
||||
// Determine whether these tensors are candidates for vectorization at all
|
||||
if constexpr (// Should be the same value_types, else the copy is also performing a cast
|
||||
sizeof(SrcType) == sizeof(DstType) &&
|
||||
// The types should be trivially copyable so that vectorization is valid
|
||||
std::is_trivially_copyable<SrcType>::value &&
|
||||
std::is_trivially_copyable<DstType>::value &&
|
||||
// Should be loading/storing real data, rather than implicit iterators or the like
|
||||
std::is_reference<SrcRef>::value &&
|
||||
std::is_reference<DstRef>::value)
|
||||
{
|
||||
return max_common_vector(a.layout(), b.layout());
|
||||
} else {
|
||||
return Int<0>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Key algebraic operations
|
||||
//
|
||||
|
||||
template <class Tensor, class Tile,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
logical_divide(Tensor && tensor,
|
||||
Tile const& tile)
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(),
|
||||
logical_divide(tensor.layout(), tile));
|
||||
}
|
||||
|
||||
// zipped_divide is logical_divide with modes gathered into standard form ((BLK_A,BLK_B),(a,b))
|
||||
template <class Tensor, class Tile,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
zipped_divide(Tensor && tensor,
|
||||
Tile const& tile) // Layout or Tile<Layout...>
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(),
|
||||
zipped_divide(tensor.layout(), tile));
|
||||
}
|
||||
|
||||
// tiled_divide is logical_divide with the second output mode flattened ((BLK_A,BLK_B),a,b)
|
||||
template <class Tensor, class Tile,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
tiled_divide(Tensor && tensor,
|
||||
Tile const& tile) // Layout or Tile<Layout...>
|
||||
{
|
||||
return make_tensor(std::forward<Tensor>(tensor).data(),
|
||||
tiled_divide(tensor.layout(), tile));
|
||||
}
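// Usage sketch (illustrative; gA, M, and N are assumptions): tile an (M,N) view into
// static (64,32) blocks.
//
//   auto   blk = make_tile(Layout<_64>{}, Layout<_32>{});
//   Tensor zA  = zipped_divide(gA, blk);   // ((_64,_32), (M/64, N/32))
//   Tensor tA  = tiled_divide (gA, blk);   // ((_64,_32),  M/64,  N/32 )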
|
||||
|
||||
// logical_product on a Tensor doesn't make sense since it often increases cosize
|
||||
|
||||
//
|
||||
// Logical Divide utilities: local_partition and local_tile
|
||||
//
|
||||
|
||||
template <class Tensor, class Tile, class Coord,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
local_partition(Tensor && tensor,
|
||||
Tile const& tile,
|
||||
Coord const& coord)
|
||||
{
|
||||
constexpr int R1 = decltype(rank(tensor))::value;
|
||||
|
||||
// Split the modes of tensor according to the modes of tile
|
||||
// zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...))
|
||||
|
||||
// The coord indexes into the first mode; flatten (keep) the rest
|
||||
return zipped_divide(std::forward<Tensor>(tensor), tile)(coord, repeat<R1>(_));
|
||||
}
|
||||
|
||||
template <class Tensor, class Tile, class Coord, class Projection,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
local_partition(Tensor && tensor,
|
||||
Tile const& tile,
|
||||
Coord const& coord,
|
||||
Projection const& proj)
|
||||
{
|
||||
return local_partition(std::forward<Tensor>(tensor),
|
||||
dice(proj, tile),
|
||||
dice(proj, coord));
|
||||
}
|
||||
|
||||
// Special case with Layout and Integral that extracts the coord first
|
||||
// e.g. local_partition(tensor, ThrLayout, threadIdx.x)
|
||||
template <class Tensor, class LShape, class LStride, class Index,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value &&
|
||||
is_integral<Index>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
local_partition(Tensor && tensor,
|
||||
Layout<LShape,LStride> const& tile,
|
||||
Index const& index)
|
||||
{
|
||||
return local_partition(std::forward<Tensor>(tensor),
|
||||
product_each(shape(tile)),
|
||||
tile.get_flat_coord(index));
|
||||
}
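// Usage sketch (illustrative; sA is an assumed (BLK_M,BLK_N) shared-memory tensor):
// give each thread of a 32x8 thread layout its strided slice of the tile.
//
//   auto   tLayout = Layout<Shape<_32,_8>>{};                   // 256 threads
//   Tensor tAsA    = local_partition(sA, tLayout, threadIdx.x); // (BLK_M/32, BLK_N/8) per thread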
|
||||
|
||||
// Special case with Layout and Integral that extracts the coord first
|
||||
// e.g. local_partition(tensor, ThrLayout, threadIdx.x, Step<_1,X,_1>{})
|
||||
template <class Tensor, class LShape, class LStride, class Index, class Projection,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value &&
|
||||
is_integral<Index>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
local_partition(Tensor && tensor,
|
||||
Layout<LShape,LStride> const& tile,
|
||||
Index const& index,
|
||||
Projection const& proj)
|
||||
{
|
||||
return local_partition(std::forward<Tensor>(tensor),
|
||||
dice(proj, product_each(shape(tile))),
|
||||
dice(proj, tile).get_flat_coord(index));
|
||||
}
|
||||
|
||||
template <class Tensor, class Tile, class Coord,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
local_tile(Tensor && tensor,
|
||||
Tile const& tile,
|
||||
Coord const& coord)
|
||||
{
|
||||
constexpr int R0 = decltype(rank(tile))::value;
|
||||
constexpr int R1 = decltype(rank(tensor))::value;
|
||||
|
||||
// Split the modes of tensor according to the modes of tile
|
||||
// zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...))
|
||||
|
||||
// The coord, padded with Underscores up to rank R1, indexes into the second mode; keep the rest
|
||||
return zipped_divide(std::forward<Tensor>(tensor), tile)(repeat<R0>(_), append<R1>(coord,_));
|
||||
}
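// Usage sketch (illustrative; mA is an assumed (M,K) global-memory tensor): select this
// CTA's (BLK_M,BLK_K) tiles, leaving the K-tile mode open for the mainloop to iterate.
//
//   Tensor gA = local_tile(mA, make_shape(Int<128>{}, Int<64>{}),   // (BLK_M,BLK_K) tiler
//                          make_coord(blockIdx.x, _));              // -> (BLK_M,BLK_K,k)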
|
||||
|
||||
template <class Tensor, class Tile, class Coord, class Proj,
|
||||
__CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
local_tile(Tensor && tensor,
|
||||
Tile const& tile,
|
||||
Coord const& coord,
|
||||
Proj const& proj)
|
||||
{
|
||||
return local_tile(std::forward<Tensor>(tensor),
|
||||
dice(proj, tile),
|
||||
dice(proj, coord));
|
||||
}
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
auto format = get_format(tensor(0));
|
||||
using type = typename decltype(format)::type;
|
||||
|
||||
if constexpr (Layout::rank == 1)
|
||||
{
|
||||
for (int m = 0; m < size(tensor); ++m) {
|
||||
printf(format.format, format.digits, type(tensor(m)));
|
||||
printf("\n");
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 2)
|
||||
{
|
||||
for (int m = 0; m < size<0>(tensor); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor); ++n) {
|
||||
printf(format.format, format.digits, type(tensor(m,n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 3)
|
||||
{
|
||||
print_tensor(tensor(_,_,0));
|
||||
for (int k = 1; k < size<2>(tensor); ++k) {
|
||||
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n");
|
||||
print_tensor(tensor(_,_,k));
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 4)
|
||||
{
|
||||
print_tensor(tensor(_,_,_,0));
|
||||
for (int p = 1; p < size<3>(tensor); ++p) {
|
||||
for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n");
|
||||
print_tensor(tensor(_,_,_,p));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST_DEVICE void print(Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
print(tensor.layout()); print("\n");
|
||||
print_tensor(tensor);
|
||||
}
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
int digits = 9;
|
||||
|
||||
if constexpr (Layout::rank == 1)
|
||||
{
|
||||
for (int m = 0; m < size(tensor); ++m) {
|
||||
os << std::setw(digits) << tensor(m) << std::endl;
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 2)
|
||||
{
|
||||
for (int m = 0; m < size<0>(tensor); ++m) {
|
||||
for (int n = 0; n < size<1>(tensor); ++n) {
|
||||
os << std::setw(digits) << tensor(m,n);
|
||||
}
|
||||
os << std::endl;
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 3)
|
||||
{
|
||||
print_tensor_os(os, tensor(_,_,0));
|
||||
for (int k = 1; k < size<2>(tensor); ++k) {
|
||||
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl;
|
||||
print_tensor_os(os, tensor(_,_,k));
|
||||
}
|
||||
} else
|
||||
if constexpr (Layout::rank == 4)
|
||||
{
|
||||
print_tensor_os(os, tensor(_,_,_,0));
|
||||
for (int p = 1; p < size<3>(tensor); ++p) {
|
||||
for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl;
|
||||
print_tensor_os(os, tensor(_,_,_,p));
|
||||
}
|
||||
}
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
template <class Engine, class Layout>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
|
||||
{
|
||||
os << tensor.layout() << std::endl;
|
||||
return print_tensor_os(os, tensor);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
//
|
||||
// Extended Engines
|
||||
//
|
||||
|
||||
#include <cute/swizzle_ptr.hpp>
|
||||
|
||||
//
|
||||
// Tensor Algorithms
|
||||
//
|
||||
|
||||
#include <cute/algorithm/tensor_algorithms.hpp>
|
||||
#include <cute/algorithm/fill.hpp>
|
||||
#include <cute/algorithm/clear.hpp>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/axpby.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
include/cute/tensor_predicate.hpp (new file, 63 lines)
@ -0,0 +1,63 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <class T>
|
||||
struct ConstantTensor
|
||||
{
|
||||
template <class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
T const&
|
||||
operator()(Coords const&...) const {
|
||||
return val_;
|
||||
}
|
||||
|
||||
T val_;
|
||||
};
|
||||
|
||||
struct TrivialPredTensor
|
||||
{
|
||||
template <class... Coords>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
true_type
|
||||
operator()(Coords const&...) const {
|
||||
return {};
|
||||
}
|
||||
};
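// Usage sketch (illustrative; copy_if from cute/algorithm/copy.hpp is an assumption here):
// a predicate "tensor" is anything callable with a coordinate, so TrivialPredTensor enables
// every element, while a real predicate tensor can mask out-of-bounds elements.
//
//   TrivialPredTensor always_true{};
//   copy_if(always_true, src_fragment, dst_fragment);   // behaves like an unpredicated copy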
|
||||
|
||||
} // end namespace cute
|
||||
include/cute/tile.hpp (new file, 58 lines)
@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/layout.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//
|
||||
// A Tile is not a Layout, it's a tuple of Layouts or Tiles or Underscores
|
||||
//
|
||||
|
||||
template <class... Layouts>
|
||||
using Tile = tuple<Layouts...>;
|
||||
|
||||
template <class Tile>
|
||||
using is_tile = is_tuple<Tile>;
|
||||
|
||||
template <class... Layouts>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
make_tile(Layouts const&... layouts)
|
||||
{
|
||||
return Tile<Layouts...>(layouts...);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
include/cute/underscore.hpp (new file, 148 lines)
@ -0,0 +1,148 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/container/tuple.hpp>
|
||||
#include <cute/algorithm/tuple_algorithms.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
#include <cute/numeric/integer_sequence.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
// For slicing
|
||||
struct Underscore : Int<0> {};
|
||||
|
||||
CUTE_INLINE_CONSTANT Underscore _;
|
||||
|
||||
// Treat Underscore as an integral like integral_constant
|
||||
template <>
|
||||
struct is_integral<Underscore> : true_type {};
|
||||
|
||||
template <class T>
|
||||
struct is_underscore : false_type {};
|
||||
template <>
|
||||
struct is_underscore<Underscore> : true_type {};
|
||||
|
||||
// Tuple trait for detecting a given static element among the leaves
|
||||
template <class Tuple, class Elem, class Enable = void>
|
||||
struct has_elem : false_type {};
|
||||
template <class Elem>
|
||||
struct has_elem<Elem, Elem> : true_type {};
|
||||
template <class Tuple, class Elem>
|
||||
struct has_elem<Tuple, Elem, std::enable_if_t<is_tuple<Tuple>::value> >
|
||||
: has_elem<Tuple, Elem, tuple_seq<Tuple> > {};
|
||||
template <class Tuple, class Elem, int... Is>
|
||||
struct has_elem<Tuple, Elem, seq<Is...>>
|
||||
: disjunction<has_elem<std::tuple_element_t<Is, Tuple>, Elem>...> {};
|
||||
|
||||
// Tuple trait for detecting whether every leaf element is a given static element
|
||||
template <class Tuple, class Elem, class Enable = void>
|
||||
struct all_elem : false_type {};
|
||||
template <class Elem>
|
||||
struct all_elem<Elem, Elem> : true_type {};
|
||||
template <class Tuple, class Elem>
|
||||
struct all_elem<Tuple, Elem, std::enable_if_t<is_tuple<Tuple>::value> >
|
||||
: all_elem<Tuple, Elem, tuple_seq<Tuple> > {};
|
||||
template <class Tuple, class Elem, int... Is>
|
||||
struct all_elem<Tuple, Elem, seq<Is...>>
|
||||
: conjunction<all_elem<std::tuple_element_t<Is, Tuple>, Elem>...> {};
|
||||
|
||||
// Tuple trait for detecting Underscore member
|
||||
template <class Tuple>
|
||||
using has_underscore = has_elem<Tuple, Underscore>;
|
||||
|
||||
template <class Tuple>
|
||||
using all_underscore = all_elem<Tuple, Underscore>;
|
||||
|
||||
template <class Tuple>
|
||||
using has_int1 = has_elem<Tuple, Int<1>>;
|
||||
|
||||
template <class Tuple>
|
||||
using has_int0 = has_elem<Tuple, Int<0>>;
|
||||
|
||||
//
|
||||
// Slice keeps only the elements of Tuple B that are paired with an Underscore
|
||||
//
|
||||
|
||||
template <class A, class B>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
slice(A const& a, B const& b)
|
||||
{
|
||||
if constexpr (is_tuple<A>::value) {
|
||||
static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks");
|
||||
return filter_tuple(a, b, [](auto const& x, auto const& y) { return slice(x,y); });
|
||||
} else if constexpr (is_underscore<A>::value) {
|
||||
return cute::tuple<B>{b};
|
||||
} else {
|
||||
return cute::tuple<>{};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
//
|
||||
// Dice keeps only the elements of Tuple B that are paired with an Int
|
||||
//
|
||||
|
||||
template <class A, class B>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
dice(A const& a, B const& b)
|
||||
{
|
||||
if constexpr (is_tuple<A>::value) {
|
||||
static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks");
|
||||
return filter_tuple(a, b, [](auto const& x, auto const& y) { return dice(x,y); });
|
||||
} else if constexpr (is_underscore<A>::value) {
|
||||
return cute::tuple<>{};
|
||||
} else {
|
||||
return cute::tuple<B>{b};
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
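// Worked example (illustrative): pairing a = (_, 3) against b = (7, 8),
//   slice(a, b) keeps the element under the Underscore -> (7)
//   dice (a, b) keeps the element under the integer    -> (8)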
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
|
||||
CUTE_HOST_DEVICE void print(Underscore const&) {
|
||||
printf("_");
|
||||
}
|
||||
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) {
|
||||
return os << "_";
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
include/cute/util/debug.hpp (new file, 153 lines)
@ -0,0 +1,153 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Debugging and logging functionality
|
||||
*/
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
/******************************************************************************
|
||||
* Debug and logging macros
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Formats and prints the given message to stdout
|
||||
*/
|
||||
#if !defined(CUTE_LOG)
|
||||
# if !defined(__CUDA_ARCH__)
|
||||
# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__)
|
||||
# else
|
||||
# define CUTE_LOG(format, ...) \
|
||||
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
|
||||
blockIdx.x, blockIdx.y, blockIdx.z, \
|
||||
threadIdx.x, threadIdx.y, threadIdx.z, \
|
||||
__VA_ARGS__);
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Formats and prints the given message to stdout only if DEBUG is defined
|
||||
*/
|
||||
#if !defined(CUTE_LOG_DEBUG)
|
||||
# ifdef DEBUG
|
||||
# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__)
|
||||
# else
|
||||
# define CUTE_LOG_DEBUG(format, ...)
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief Perror macro with exit
|
||||
*/
|
||||
#if !defined(CUTE_ERROR_EXIT)
|
||||
# define CUTE_ERROR_EXIT(e) \
|
||||
do { \
|
||||
cudaError_t code = (e); \
|
||||
if (code != cudaSuccess) { \
|
||||
fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \
|
||||
__FILE__, __LINE__, #e, \
|
||||
cudaGetErrorName(code), cudaGetErrorString(code)); \
|
||||
fflush(stderr); \
|
||||
exit(0); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#if !defined(CUTE_CHECK_LAST)
|
||||
# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize())
|
||||
#endif
|
||||
|
||||
#if !defined(CUTE_CHECK_ERROR)
|
||||
# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e)
|
||||
#endif
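// Usage sketch (illustrative; d_ptr, bytes, and the kernel launch are assumptions):
//
//   CUTE_CHECK_ERROR(cudaMalloc(&d_ptr, bytes));
//   my_kernel<<<grid, block>>>(d_ptr);
//   CUTE_CHECK_LAST();   // checks cudaPeekAtLastError() then cudaDeviceSynchronize()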
|
||||
|
||||
// A dummy function that uses compilation failure to print a type
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print_type(T&&) {
|
||||
static_assert(sizeof(T) < 0, "Printing type T.");
|
||||
}
|
||||
|
||||
//
|
||||
// Device-specific helpers
|
||||
//
|
||||
// e.g.
|
||||
// if (thread0()) print(...);
|
||||
// if (block0()) print(...);
|
||||
// if (thread(42)) print(...);
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
bool
|
||||
thread(int tid, int bid)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid)
|
||||
&& ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid);
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
bool
|
||||
thread(int tid)
|
||||
{
|
||||
return thread(tid, 0);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
bool
|
||||
thread0()
|
||||
{
|
||||
return thread(0,0);
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
bool
|
||||
block0()
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return !(blockIdx.x | blockIdx.y | blockIdx.z);
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
include/cute/util/print.hpp (new file, 140 lines)
@ -0,0 +1,140 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
//
|
||||
// CUDA compatible print and printf
|
||||
//
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
int
|
||||
num_digits(int x)
|
||||
{
|
||||
return (x < 10 ? 1 :
|
||||
(x < 100 ? 2 :
|
||||
(x < 1000 ? 3 :
|
||||
(x < 10000 ? 4 :
|
||||
(x < 100000 ? 5 :
|
||||
(x < 1000000 ? 6 :
|
||||
(x < 10000000 ? 7 :
|
||||
(x < 100000000 ? 8 :
|
||||
(x < 1000000000 ? 9 :
|
||||
10)))))))));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct format_and_size {
|
||||
using type = T;
|
||||
char const* format;
|
||||
int digits;
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<int>
|
||||
get_format(bool) {
|
||||
return {"%*d", 3};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<int32_t>
|
||||
get_format(int32_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<uint32_t>
|
||||
get_format(uint32_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<int64_t>
|
||||
get_format(int64_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<uint64_t>
|
||||
get_format(uint64_t) {
|
||||
return {"%*d", 5};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<float>
|
||||
get_format(half_t) {
|
||||
return {"%*.2f", 8};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<float>
|
||||
get_format(float) {
|
||||
return {"%*.2e", 10};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
format_and_size<double>
|
||||
get_format(double) {
|
||||
return {"%*.3e", 11};
|
||||
}
|
||||
|
||||
//
|
||||
// print dispatcher
|
||||
//
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(char const& c) {
|
||||
printf("%c", c);
|
||||
}
|
||||
|
||||
template <class T,
|
||||
__CUTE_REQUIRES(std::is_integral<T>::value)>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(T const& a) {
|
||||
printf("%d", int(a));
|
||||
}
|
||||
|
||||
template <class... T>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(char const* format, T const&... t) {
|
||||
printf(format, t...);
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
include/cute/util/type_traits.hpp (new file, 101 lines)
@ -0,0 +1,101 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr
|
||||
#define __CUTE_REQUIRES_V(...) typename std::enable_if<decltype((__VA_ARGS__))::value>::type* = nullptr
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
using std::conjunction;
|
||||
using std::conjunction_v;
|
||||
|
||||
using std::disjunction;
|
||||
using std::disjunction_v;
|
||||
|
||||
using std::negation;
|
||||
using std::negation_v;
|
||||
|
||||
using std::void_t;
|
||||
|
||||
// C++20
|
||||
// using std::remove_cvref;
|
||||
template <class T>
|
||||
struct remove_cvref {
|
||||
using type = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
};
|
||||
|
||||
// C++20
|
||||
// using std::remove_cvref_t;
|
||||
template <class T>
|
||||
using remove_cvref_t = typename remove_cvref<T>::type;
|
||||
|
||||
//
|
||||
// is_valid
|
||||
//
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class F, class... Args, class = decltype(std::declval<F&&>()(std::declval<Args&&>()...))>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
is_valid_impl(int) { return std::true_type{}; }
|
||||
|
||||
template <class F, class... Args>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
is_valid_impl(...) { return std::false_type{}; }
|
||||
|
||||
template <class F>
|
||||
struct is_valid_fn {
|
||||
template <class... Args>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
operator()(Args&&...) const { return is_valid_impl<F, Args&&...>(int{}); }
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
template <class F>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
is_valid(F&&) {
|
||||
return detail::is_valid_fn<F&&>{};
|
||||
}
|
||||
|
||||
template <class F, class... Args>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
is_valid(F&&, Args&&...) {
|
||||
return detail::is_valid_impl<F&&, Args&&...>(int{});
|
||||
}
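// Usage sketch (illustrative): detect at compile time whether an expression is well-formed.
//
//   auto has_layout = is_valid([](auto&& t) -> decltype(t.layout()) { return t.layout(); });
//   // decltype(has_layout(some_tensor))::value is true iff some_tensor.layout() compiles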
|
||||
|
||||
} // end namespace cute
|
||||
include/cutlass/arch/barrier.h (new file, 404 lines)
@ -0,0 +1,404 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are not permit-
|
||||
* ted.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Barrier Operations on SM90+
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/memory_sm75.h>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
|
||||
namespace cutlass {
|
||||
/// @brief
|
||||
namespace arch {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUDA_BARRIER_ENABLED 1
|
||||
#else
|
||||
#define CUDA_BARRIER_ENABLED 0
|
||||
#endif
|
||||
|
||||
class NamedBarrier {
|
||||
|
||||
// Data Members:
|
||||
|
||||
// Range = [1, NUM_THREADS_PER_CTA]
// Range % warp-size (i.e. 32) == 0
|
||||
uint32_t const num_threads_;
|
||||
|
||||
// Range : [0, 15]
|
||||
uint32_t const id_;
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
NamedBarrier(uint32_t num_threads, uint32_t id = 0)
|
||||
: num_threads_(num_threads), id_(id) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void arrive_and_wait() const {
|
||||
NamedBarrier::arrive_and_wait(num_threads_, id_);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void arrive() const {
|
||||
NamedBarrier::arrive(num_threads_, id_);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void sync() const {
|
||||
NamedBarrier::arrive_and_wait();
|
||||
}
|
||||
|
||||
// Static variants
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void arrive(uint32_t num_threads, uint32_t barrier_id) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void sync(uint32_t num_threads, uint32_t barrier_id) {
|
||||
NamedBarrier::arrive_and_wait(num_threads, barrier_id);
|
||||
}
|
||||
};
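// Usage sketch (illustrative): synchronize a 128-thread group on named barrier id 1,
// independently of the rest of the CTA.
//
//   cutlass::arch::NamedBarrier producer_bar(128, 1);
//   producer_bar.arrive_and_wait();   // bar.sync 1, 128 for these threads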
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Hopper introduces a new cluster-wide barrier that handles arrive/wait (AW) behaviour at cluster scope.
// This is an extension of the Ampere AW barriers.
// Note: Ampere AW barriers have a larger max-arrive count (2^30) than Hopper AW barriers (2^20).
|
||||
struct ClusterBarrier {
|
||||
|
||||
using ValueType = uint64_t;
|
||||
|
||||
protected:
|
||||
// Can never be initialized directly - can only be aliased to smem
|
||||
ValueType barrier_;
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ClusterBarrier() = delete;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void init(uint32_t arrive_count) const {
|
||||
ClusterBarrier::init(&this->barrier_, arrive_count);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
uint32_t test_wait(uint32_t phase, uint32_t pred=true) const {
|
||||
return ClusterBarrier::test_wait(&this->barrier_, phase, pred);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void wait(uint32_t phase) const {
|
||||
ClusterBarrier::wait(&this->barrier_, phase);
|
||||
}
|
||||
|
||||
// Barrier arrive on local smem
|
||||
CUTLASS_DEVICE
|
||||
void arrive() const {
|
||||
ClusterBarrier::arrive(&this->barrier_);
|
||||
}
|
||||
|
||||
// Remote SMEM arrive with a predicate (usually used to pick the thread doing the arrive)
|
||||
CUTLASS_DEVICE
|
||||
void arrive(uint32_t cta_id, uint32_t pred = true ) const {
|
||||
ClusterBarrier::arrive(&this->barrier_, cta_id, pred);
|
||||
}
|
||||
|
||||
//
|
||||
// Static Versions
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
static void init(ValueType const* smem_ptr, uint32_t arrive_count) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"mbarrier.init.shared.b64 [%1], %0; \n"
|
||||
"}"
|
||||
:
|
||||
: "r"(arrive_count), "r"(smem_addr));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Static version of wait - in case we don't want to burn a register
|
||||
CUTLASS_DEVICE
|
||||
static void wait(ValueType const* smem_ptr, uint32_t phase) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
// Arbitrarily large timer value after which try-wait expires and re-tries.
|
||||
uint32_t ticks = 0x989680;
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred P1; \n\t"
|
||||
"LAB_WAIT: \n\t"
|
||||
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
|
||||
"@P1 bra.uni DONE; \n\t"
|
||||
"bra.uni LAB_WAIT; \n\t"
|
||||
"DONE: \n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(smem_addr), "r"(phase), "r"(ticks));
|
||||
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
uint32_t waitComplete;
|
||||
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred P1; \n\t"
|
||||
".reg .pred P2; \n\t"
|
||||
"setp.eq.u32 P2, %3, 1;\n\t"
|
||||
"@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t"
|
||||
"selp.b32 %0, 1, 0, P1; \n\t"
|
||||
"}"
|
||||
: "=r"(waitComplete)
|
||||
: "r"(smem_addr), "r"(phase), "r"(pred));
|
||||
|
||||
return waitComplete;
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Static Predicated version of the above - in case we know the address.
|
||||
CUTLASS_DEVICE
|
||||
static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
".reg .b32 remAddr32;\n\t"
|
||||
"setp.eq.u32 p, %2, 1;\n\t"
|
||||
"@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
|
||||
"@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(smem_addr), "r"(cta_id), "r"(pred));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Barrier arrive on local smem
|
||||
CUTLASS_DEVICE
|
||||
static void arrive(ValueType const* smem_ptr) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
uint64_t state = 0;
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"mbarrier.arrive.shared.b64 %1, [%0];\n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(smem_addr), "l"(state));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void invalidate(ValueType const* smem_ptr) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"mbarrier.ival.shared.b64 [%0]; \n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(smem_addr));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SM90 also introduces a new type of cluster barrier which supports synchronization
// based not just on arrive count, but also on transaction count (in bytes)
|
||||
struct ClusterTransactionBarrier : public ClusterBarrier {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ClusterTransactionBarrier() = delete;
|
||||
|
||||
// Performs an arrive operation + bytes reset
|
||||
CUTLASS_DEVICE
|
||||
void arrive_and_reset_bytes(uint32_t transaction_bytes) const {
|
||||
ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes);
|
||||
}
|
||||
|
||||
// Performs an arrive operation + bytes reset
|
||||
CUTLASS_DEVICE
|
||||
void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const {
|
||||
ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes , cta_id, true);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void commit(uint32_t transaction_bytes, uint32_t pred = 1) const {
|
||||
uint32_t cta_rank = cute::block_rank_in_cluster();
|
||||
ClusterTransactionBarrier::commit(&this->barrier_, cta_rank, transaction_bytes, pred);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const {
|
||||
ClusterTransactionBarrier::commit(&this->barrier_, dst_cta_id, transaction_bytes, pred);
|
||||
}
|
||||
|
||||
//
|
||||
// Static Versions
|
||||
//
|
||||
|
||||
// Performs an arrive operation + bytes reset
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(transaction_bytes), "r"(smem_addr));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Performs an arrive operation + bytes reset for a remote cta_id in a Cluster
|
||||
CUTLASS_DEVICE
|
||||
static void arrive_and_reset_bytes(
|
||||
ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
".reg .b32 remAddr32;\n\t"
|
||||
"setp.eq.u32 p, %2, 1;\n\t"
|
||||
"@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
|
||||
"@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Performs a transaction-bytes reset without doing an arrive operation
|
||||
CUTLASS_DEVICE
|
||||
static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"mbarrier.expect_tx.shared.b64 [%1], %0; \n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(transaction_bytes), "r"(smem_addr));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Increments transaction bytes in the barrier
|
||||
CUTLASS_DEVICE
|
||||
static void commit(
|
||||
ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) {
|
||||
#if CUDA_BARRIER_ENABLED
|
||||
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
smem_addr = cute::set_block_rank(smem_addr, dst_cta_id);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.eq.u32 p, %2, 1;\n\t"
|
||||
"@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;"
|
||||
"}"
|
||||
:
|
||||
: "r"(transaction_bytes), "r"(smem_addr), "r"(pred));
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
};
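
// Illustrative usage sketch (editorial addition, not part of the original header): a minimal
// one-producer exchange through a ClusterTransactionBarrier. It assumes the barrier has already
// been initialized and fenced (see fence_barrier_init() below) and that the ClusterBarrier base
// class exposes a wait(phase) member; `is_producer`, `transaction_bytes`, and `phase` are
// caller-supplied values.
CUTLASS_DEVICE
void cluster_transaction_barrier_example(
    ClusterTransactionBarrier const& bar,
    bool is_producer,
    uint32_t transaction_bytes,
    uint32_t phase) {
  if (is_producer) {
    // Arrive and announce how many bytes the consumers should expect.
    bar.arrive_and_reset_bytes(transaction_bytes);
    // A TMA copy programmed against `bar` updates the transaction count automatically;
    // a manual shared-memory write instead accounts for its bytes explicitly.
    bar.commit(transaction_bytes);
  }
  // Consumers block until both the arrive count and the expected byte count are satisfied.
  bar.wait(phase);
}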

// Helps with visibility of barrier init operations across warps / CTAs / clusters.
// Available as a separate function so that inits across several barriers can be batched and fenced once.
// Note: it must be composed with an appropriate sync instruction of the right scope to ensure
// visibility, e.g. __syncthreads() or a cluster_arrive() + cluster_wait().
CUTLASS_DEVICE
void fence_barrier_init() {
#if CUDA_BARRIER_ENABLED
  asm volatile(
      "{\n\t"
      "fence.mbarrier_init.release.cluster; \n"
      "}"
      ::);
#else
  asm volatile ("brkpt;\n" ::);
#endif
}
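
// Illustrative usage sketch (editorial addition): initialize several barriers from one elected
// thread, issue a single fence_barrier_init(), then synchronize so every participating thread
// observes the initialized state. ClusterBarrier::init(arrive_count) is assumed to be available
// from earlier in this header; the barrier array and arrive count are caller-supplied.
CUTLASS_DEVICE
void barrier_init_example(ClusterBarrier* barriers, int num_barriers, uint32_t arrive_count) {
  if (threadIdx.x == 0) {
    for (int i = 0; i < num_barriers; ++i) {
      barriers[i].init(arrive_count);
    }
    // One fence covers all of the inits batched above.
    fence_barrier_init();
  }
  // Make the initialized barriers visible; a cluster_arrive() + cluster_wait() pair
  // would be used instead when the barriers are shared across a cluster.
  __syncthreads();
}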

// Issue a shared memory fence for async operations
CUTLASS_DEVICE
void fence_view_async_shared() {
#if CUDA_BARRIER_ENABLED
  asm volatile (
      "{\n\t"
      "fence.proxy.async.shared::cta; \n"
      "}"
      ::);
#else
  asm volatile ("brkpt;\n" ::);
#endif
}
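
// Illustrative usage sketch (editorial addition): generic-proxy writes to shared memory must be
// fenced with fence_view_async_shared() before the same buffer is read through the async proxy
// (for example by a TMA store). The async operation itself is outside the scope of this header,
// so it is only indicated by a comment here.
CUTLASS_DEVICE
void async_proxy_fence_example(uint32_t* smem_staging, uint32_t value) {
  // 1. Produce data in shared memory through ordinary (generic-proxy) stores.
  smem_staging[threadIdx.x] = value;
  // 2. Order those stores before any subsequent async-proxy access to the same buffer.
  fence_view_async_shared();
  // 3. ... issue the async-proxy operation (e.g. a TMA store reading `smem_staging`) here.
}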

////////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////////
} // end namespace arch
} // end namespace cutlass
@@ -36,6 +36,7 @@

#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cute/arch/util.hpp"

namespace cutlass {
namespace arch {
@@ -65,74 +66,13 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
#define CUDA_LDMATRIX_SUPPORTED 1
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////
/*
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1
#endif
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
#endif

#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif
*/

#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template (an illustrative sketch follows this block).
//
__device__ uint32_t __nvvm_get_smem_pointer(void *);
}
#endif

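// Illustrative usage sketch (editorial addition): fragments are loaded through the
// cutlass::arch::ldsm<>() template rather than by calling the NVVM intrinsic directly.
// The Layout template argument and the 4-matrix variant are assumptions based on the
// specializations defined later in this header; they are not shown in this diff.
inline __device__ void ldsm_example(Array<unsigned, 4>& frag, void const* smem_ptr) {
  // Load four 8x8 matrices from shared memory into the register fragment.
  ldsm<layout::RowMajor, 4>(frag, smem_ptr);
}
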
/////////////////////////////////////////////////////////////////////////////////////////////////

/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {

// We prefer the new CVTA intrinsics when they are available; otherwise we fall back to
// the previous internal intrinsics.
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
  //
  // This NVVM intrinsic converts an address in shared memory to a plain
  // unsigned integer. This is necessary to pass to shared memory instructions
  // in inline PTX.
  //
  // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
  //
  //__device__ size_t __cvta_generic_to_shared(void* ptr);

  return static_cast<unsigned>(__cvta_generic_to_shared(ptr));

#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)

  return __nvvm_get_smem_pointer(ptr);

#elif defined(__CUDA_ARCH__)

  uint32_t smem_ptr;

  asm(
      "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
      : "=r"(smem_ptr) : "l"(ptr));

  return smem_ptr;

#else

  CUTLASS_UNUSED(ptr);
  CUTLASS_NOT_IMPLEMENTED();
  return 0;

#endif
  return cute::cast_smem_ptr_to_uint(ptr);
}

/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
  return cutlass_get_smem_pointer(const_cast<void *>(ptr));
}
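
// Illustrative usage sketch (editorial addition): converting a generic shared-memory pointer
// into the 32-bit address form that shared-memory instructions in inline PTX expect. The
// st.shared.u32 instruction used here is only one example of such a consumer.
inline __device__ void store_to_smem_example(void* smem_dst, uint32_t value) {
  unsigned smem_addr = cutlass_get_smem_pointer(smem_dst);
  asm volatile("st.shared.u32 [%0], %1;\n" :: "r"(smem_addr), "r"(value));
}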